diff --git a/external/llite/library/__init__.py b/external/llite/library/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/external/llite/library/attention_processors.py b/external/llite/library/attention_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..310c2cb1c63955f8f03296c54fd47c21f1a981c9
--- /dev/null
+++ b/external/llite/library/attention_processors.py
@@ -0,0 +1,227 @@
+import math
+from typing import Any
+from einops import rearrange
+import torch
+from diffusers.models.attention_processor import Attention
+
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+EPSILON = 1e-6
+
+
+class FlashAttentionFunction(torch.autograd.function.Function):
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """Algorithm 2 in the paper"""
+
+        device = q.device
+        dtype = q.dtype
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+        all_row_maxes = torch.full(
+            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
+        )
+
+        scale = q.shape[-1] ** -0.5
+
+        if mask is None:
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, "b n -> b 1 1 n")
+            mask = mask.split(q_bucket_size, dim=-1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim=-2),
+            all_row_maxes.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = (
+                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+                )
+
+                if row_mask is not None:
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones(
+                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
+                    ).triu(q_start_index - k_start_index + 1)
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if row_mask is not None:
+                    exp_weights.masked_fill_(~row_mask, 0.0)
+
+                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
+                    min=EPSILON
+                )
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = torch.einsum(
+                    "... i j, ... j d -> ... i d", exp_weights, vc
+                )
+
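+                # online softmax update: rescale the previously accumulated output and
+                # row sums so both running blocks share the new row-wise maximum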
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                new_row_sums = (
+                    exp_row_max_diff * row_sums
+                    + exp_block_row_max_diff * block_row_sums
+                )
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
+                    (exp_block_row_max_diff / new_row_sums) * exp_values
+                )
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+        return o
+
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """Algorithm 4 in the paper"""
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, l, m = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            do.split(q_bucket_size, dim=-2),
+            mask,
+            l.split(q_bucket_size, dim=-2),
+            m.split(q_bucket_size, dim=-2),
+            dq.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+                dk.split(k_bucket_size, dim=-2),
+                dv.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = (
+                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+                )
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones(
+                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
+                    ).triu(q_start_index - k_start_index + 1)
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
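+                # recompute the attention probabilities from the saved row statistics
+                # (m = row maxes, l = row sums) instead of storing the full attention matrix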
+                exp_attn_weights = torch.exp(attn_weights - mc)
+
+                if row_mask is not None:
+                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                p = exp_attn_weights / lc
+
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
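+                # softmax backward: D is the row-wise sum of do * o, and
+                # ds = p * scale * (dp - D) is the gradient w.r.t. the pre-softmax scores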
+                D = (doc * oc).sum(dim=-1, keepdims=True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None
+
+
+class FlashAttnProcessor:
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+    ) -> Any:
+        q_bucket_size = 512
+        k_bucket_size = 1024
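+        # chunk sizes along the query / key sequence dimensions; larger buckets
+        # run faster but use more memory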
+
+        h = attn.heads
+        q = attn.to_q(hidden_states)
+
+        encoder_hidden_states = (
+            encoder_hidden_states
+            if encoder_hidden_states is not None
+            else hidden_states
+        )
+        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
+
+        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
+            context_k, context_v = attn.hypernetwork.forward(
+                hidden_states, encoder_hidden_states
+            )
+            context_k = context_k.to(hidden_states.dtype)
+            context_v = context_v.to(hidden_states.dtype)
+        else:
+            context_k = encoder_hidden_states
+            context_v = encoder_hidden_states
+
+        k = attn.to_k(context_k)
+        v = attn.to_v(context_v)
+        del encoder_hidden_states, hidden_states
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = FlashAttentionFunction.apply(
+            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
+        )
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        out = attn.to_out[0](out)
+        out = attn.to_out[1](out)
+        return out
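+
+
+# Usage (a minimal sketch, assuming a diffusers UNet2DConditionModel instance named `unet`):
+#   unet.set_attn_processor(FlashAttnProcessor())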
diff --git a/external/llite/library/config_util.py b/external/llite/library/config_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..47868f3baf8f3a34c0026a4246bafafc3083b0d1
--- /dev/null
+++ b/external/llite/library/config_util.py
@@ -0,0 +1,621 @@
+import argparse
+from dataclasses import (
+  asdict,
+  dataclass,
+)
+import functools
+import random
+from textwrap import dedent, indent
+import json
+from pathlib import Path
+# from toolz import curry
+from typing import (
+  List,
+  Optional,
+  Sequence,
+  Tuple,
+  Union,
+)
+
+import toml
+import voluptuous
+from voluptuous import (
+  Any,
+  ExactSequence,
+  MultipleInvalid,
+  Object,
+  Required,
+  Schema,
+)
+from transformers import CLIPTokenizer
+
+from . import train_util
+from .train_util import (
+  DreamBoothSubset,
+  FineTuningSubset,
+  ControlNetSubset,
+  DreamBoothDataset,
+  FineTuningDataset,
+  ControlNetDataset,
+  DatasetGroup,
+)
+
+
+def add_config_arguments(parser: argparse.ArgumentParser):
+  parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+
+# TODO: inherit Params class in Subset, Dataset
+
+@dataclass
+class BaseSubsetParams:
+  image_dir: Optional[str] = None
+  num_repeats: int = 1
+  shuffle_caption: bool = False
+  caption_separator: str = ","
+  keep_tokens: int = 0
+  keep_tokens_separator: Optional[str] = None
+  color_aug: bool = False
+  flip_aug: bool = False
+  face_crop_aug_range: Optional[Tuple[float, float]] = None
+  random_crop: bool = False
+  caption_prefix: Optional[str] = None
+  caption_suffix: Optional[str] = None
+  caption_dropout_rate: float = 0.0
+  caption_dropout_every_n_epochs: int = 0
+  caption_tag_dropout_rate: float = 0.0
+  token_warmup_min: int = 1
+  token_warmup_step: float = 0
+
+@dataclass
+class DreamBoothSubsetParams(BaseSubsetParams):
+  is_reg: bool = False
+  class_tokens: Optional[str] = None
+  caption_extension: str = ".caption"
+
+@dataclass
+class FineTuningSubsetParams(BaseSubsetParams):
+  metadata_file: Optional[str] = None
+
+@dataclass
+class ControlNetSubsetParams(BaseSubsetParams):
+  conditioning_data_dir: Optional[str] = None
+  caption_extension: str = ".caption"
+
+@dataclass
+class BaseDatasetParams:
+  tokenizer: Optional[Union[CLIPTokenizer, List[CLIPTokenizer]]] = None
+  max_token_length: Optional[int] = None
+  resolution: Optional[Tuple[int, int]] = None
+  debug_dataset: bool = False
+
+@dataclass
+class DreamBoothDatasetParams(BaseDatasetParams):
+  batch_size: int = 1
+  enable_bucket: bool = False
+  min_bucket_reso: int = 256
+  max_bucket_reso: int = 1024
+  bucket_reso_steps: int = 64
+  bucket_no_upscale: bool = False
+  prior_loss_weight: float = 1.0
+
+@dataclass
+class FineTuningDatasetParams(BaseDatasetParams):
+  batch_size: int = 1
+  enable_bucket: bool = False
+  min_bucket_reso: int = 256
+  max_bucket_reso: int = 1024
+  bucket_reso_steps: int = 64
+  bucket_no_upscale: bool = False
+
+@dataclass
+class ControlNetDatasetParams(BaseDatasetParams):
+  batch_size: int = 1
+  enable_bucket: bool = False
+  min_bucket_reso: int = 256
+  max_bucket_reso: int = 1024
+  bucket_reso_steps: int = 64
+  bucket_no_upscale: bool = False
+
+@dataclass
+class SubsetBlueprint:
+  params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
+
+@dataclass
+class DatasetBlueprint:
+  is_dreambooth: bool
+  is_controlnet: bool
+  params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
+  subsets: Sequence[SubsetBlueprint]
+
+@dataclass
+class DatasetGroupBlueprint:
+  datasets: Sequence[DatasetBlueprint]
+
+@dataclass
+class Blueprint:
+  dataset_group: DatasetGroupBlueprint
+
+
+class ConfigSanitizer:
+  # @curry
+  @staticmethod
+  def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
+    Schema(ExactSequence([klass, klass]))(value)
+    return tuple(value)
+
+  # @curry
+  @staticmethod
+  def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
+    Schema(Any(klass, ExactSequence([klass, klass])))(value)
+    try:
+      Schema(klass)(value)
+      return (value, value)
+    except voluptuous.Invalid:
+      return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
+
+  # subset schema
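+  # "ascendable" options may also be set at the dataset or general level and are
+  # inherited by subsets that do not override them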
+  SUBSET_ASCENDABLE_SCHEMA = {
+    "color_aug": bool,
+    "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
+    "flip_aug": bool,
+    "num_repeats": int,
+    "random_crop": bool,
+    "shuffle_caption": bool,
+    "keep_tokens": int,
+    "keep_tokens_separator": str,
+    "token_warmup_min": int,
+    "token_warmup_step": Any(float,int),
+    "caption_prefix": str,
+    "caption_suffix": str,
+  }
+  # DO means DropOut
+  DO_SUBSET_ASCENDABLE_SCHEMA = {
+    "caption_dropout_every_n_epochs": int,
+    "caption_dropout_rate": Any(float, int),
+    "caption_tag_dropout_rate": Any(float, int),
+  }
+  # DB means DreamBooth
+  DB_SUBSET_ASCENDABLE_SCHEMA = {
+    "caption_extension": str,
+    "class_tokens": str,
+  }
+  DB_SUBSET_DISTINCT_SCHEMA = {
+    Required("image_dir"): str,
+    "is_reg": bool,
+  }
+  # FT means FineTuning
+  FT_SUBSET_DISTINCT_SCHEMA = {
+    Required("metadata_file"): str,
+    "image_dir": str,
+  }
+  CN_SUBSET_ASCENDABLE_SCHEMA = {
+    "caption_extension": str,
+  }
+  CN_SUBSET_DISTINCT_SCHEMA = {
+    Required("image_dir"): str,
+    Required("conditioning_data_dir"): str,
+  }
+
+  # datasets schema
+  DATASET_ASCENDABLE_SCHEMA = {
+    "batch_size": int,
+    "bucket_no_upscale": bool,
+    "bucket_reso_steps": int,
+    "enable_bucket": bool,
+    "max_bucket_reso": int,
+    "min_bucket_reso": int,
+    "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
+  }
+
+  # options handled by argparse but not handled by user config
+  ARGPARSE_SPECIFIC_SCHEMA = {
+    "debug_dataset": bool,
+    "max_token_length": Any(None, int),
+    "prior_loss_weight": Any(float, int),
+  }
+  # for handling default None values from argparse
+  ARGPARSE_NULLABLE_OPTNAMES = [
+    "face_crop_aug_range",
+    "resolution",
+  ]
+  # prepare a map because option names may differ between argparse and the user config
+  ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
+    "train_batch_size": "batch_size",
+    "dataset_repeats": "num_repeats",
+  }
+
+  def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
+    assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+
+    self.db_subset_schema = self.__merge_dict(
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.DB_SUBSET_DISTINCT_SCHEMA,
+      self.DB_SUBSET_ASCENDABLE_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+    )
+
+    self.ft_subset_schema = self.__merge_dict(
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.FT_SUBSET_DISTINCT_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+    )
+
+    self.cn_subset_schema = self.__merge_dict(
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.CN_SUBSET_DISTINCT_SCHEMA,
+      self.CN_SUBSET_ASCENDABLE_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+    )
+
+    self.db_dataset_schema = self.__merge_dict(
+      self.DATASET_ASCENDABLE_SCHEMA,
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.DB_SUBSET_ASCENDABLE_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+      {"subsets": [self.db_subset_schema]},
+    )
+
+    self.ft_dataset_schema = self.__merge_dict(
+      self.DATASET_ASCENDABLE_SCHEMA,
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+      {"subsets": [self.ft_subset_schema]},
+    )
+
+    self.cn_dataset_schema = self.__merge_dict(
+      self.DATASET_ASCENDABLE_SCHEMA,
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.CN_SUBSET_ASCENDABLE_SCHEMA,
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+      {"subsets": [self.cn_subset_schema]},
+    )
+
+    if support_dreambooth and support_finetuning:
+      def validate_flex_dataset(dataset_config: dict):
+        subsets_config = dataset_config.get("subsets", [])
+
+        if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
+          return Schema(self.cn_dataset_schema)(dataset_config)
+        # check dataset meets FT style
+        # NOTE: all FT subsets should have "metadata_file"
+        elif all(["metadata_file" in subset for subset in subsets_config]):
+          return Schema(self.ft_dataset_schema)(dataset_config)
+        # check dataset meets DB style
+        # NOTE: all DB subsets should have no "metadata_file"
+        elif all(["metadata_file" not in subset for subset in subsets_config]):
+          return Schema(self.db_dataset_schema)(dataset_config)
+        else:
+          raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")
+
+      self.dataset_schema = validate_flex_dataset
+    elif support_dreambooth:
+      self.dataset_schema = self.db_dataset_schema
+    elif support_finetuning:
+      self.dataset_schema = self.ft_dataset_schema
+    elif support_controlnet:
+      self.dataset_schema = self.cn_dataset_schema
+
+    self.general_schema = self.__merge_dict(
+      self.DATASET_ASCENDABLE_SCHEMA,
+      self.SUBSET_ASCENDABLE_SCHEMA,
+      self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
+      self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
+      self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
+    )
+
+    self.user_config_validator = Schema({
+      "general": self.general_schema,
+      "datasets": [self.dataset_schema],
+    })
+
+    self.argparse_schema = self.__merge_dict(
+      self.general_schema,
+      self.ARGPARSE_SPECIFIC_SCHEMA,
+      {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
+      {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
+    )
+
+    self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
+
+  def sanitize_user_config(self, user_config: dict) -> dict:
+    try:
+      return self.user_config_validator(user_config)
+    except MultipleInvalid:
+      # TODO: make the error message easier to understand
+      print("Invalid user config / ユーザ設定の形式が正しくないようです")
+      raise
+
+  # NOTE: In principle, the argparse result does not need to be sanitized,
+  #   but doing so helps us detect program bugs
+  def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
+    try:
+      return self.argparse_config_validator(argparse_namespace)
+    except MultipleInvalid:
+      # XXX: this should be a bug
+      print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+      raise
+
+  # NOTE: values are overwritten by later dicts when the same key appears more than once
+  @staticmethod
+  def __merge_dict(*dict_list: dict) -> dict:
+    merged = {}
+    for schema in dict_list:
+      # merged |= schema
+      for k, v in schema.items():
+        merged[k] = v
+    return merged
+
+
+class BlueprintGenerator:
+  BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {
+  }
+
+  def __init__(self, sanitizer: ConfigSanitizer):
+    self.sanitizer = sanitizer
+
+  # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
+  def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
+    sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
+    sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
+
+    # convert argparse namespace to dict like config
+    # NOTE: it is ok to have extra entries in dict
+    optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
+    argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()}
+
+    general_config = sanitized_user_config.get("general", {})
+
+    dataset_blueprints = []
+    for dataset_config in sanitized_user_config.get("datasets", []):
+      # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
+      subsets = dataset_config.get("subsets", [])
+      is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
+      is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
+      if is_controlnet:
+        subset_params_klass = ControlNetSubsetParams
+        dataset_params_klass = ControlNetDatasetParams
+      elif is_dreambooth:
+        subset_params_klass = DreamBoothSubsetParams
+        dataset_params_klass = DreamBoothDatasetParams
+      else:
+        subset_params_klass = FineTuningSubsetParams
+        dataset_params_klass = FineTuningDatasetParams
+
+      subset_blueprints = []
+      for subset_config in subsets:
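+        # resolve each parameter with the first non-None value in priority order:
+        # subset > dataset > general > argparse > runtime defaults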
+        params = self.generate_params_by_fallbacks(subset_params_klass,
+                                                   [subset_config, dataset_config, general_config, argparse_config, runtime_params])
+        subset_blueprints.append(SubsetBlueprint(params))
+
+      params = self.generate_params_by_fallbacks(dataset_params_klass,
+                                                 [dataset_config, general_config, argparse_config, runtime_params])
+      dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
+
+    dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
+
+    return Blueprint(dataset_group_blueprint)
+
+  @staticmethod
+  def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
+    name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
+    search_value = BlueprintGenerator.search_value
+    default_params = asdict(param_klass())
+    param_names = default_params.keys()
+
+    params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
+
+    return param_klass(**params)
+
+  @staticmethod
+  def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
+    for cand in fallbacks:
+      value = cand.get(key)
+      if value is not None:
+        return value
+
+    return default_value
+
+
+def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
+  datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
+
+  for dataset_blueprint in dataset_group_blueprint.datasets:
+    if dataset_blueprint.is_controlnet:
+      subset_klass = ControlNetSubset
+      dataset_klass = ControlNetDataset
+    elif dataset_blueprint.is_dreambooth:
+      subset_klass = DreamBoothSubset
+      dataset_klass = DreamBoothDataset
+    else:
+      subset_klass = FineTuningSubset
+      dataset_klass = FineTuningDataset
+
+    subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
+    dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
+    datasets.append(dataset)
+
+  # print info
+  info = ""
+  for i, dataset in enumerate(datasets):
+    is_dreambooth = isinstance(dataset, DreamBoothDataset)
+    is_controlnet = isinstance(dataset, ControlNetDataset)
+    info += dedent(f"""\
+      [Dataset {i}]
+        batch_size: {dataset.batch_size}
+        resolution: {(dataset.width, dataset.height)}
+        enable_bucket: {dataset.enable_bucket}
+    """)
+
+    if dataset.enable_bucket:
+      info += indent(dedent(f"""\
+        min_bucket_reso: {dataset.min_bucket_reso}
+        max_bucket_reso: {dataset.max_bucket_reso}
+        bucket_reso_steps: {dataset.bucket_reso_steps}
+        bucket_no_upscale: {dataset.bucket_no_upscale}
+      \n"""), "  ")
+    else:
+      info += "\n"
+
+    for j, subset in enumerate(dataset.subsets):
+      info += indent(dedent(f"""\
+        [Subset {j} of Dataset {i}]
+          image_dir: "{subset.image_dir}"
+          image_count: {subset.img_count}
+          num_repeats: {subset.num_repeats}
+          shuffle_caption: {subset.shuffle_caption}
+          keep_tokens: {subset.keep_tokens}
+          keep_tokens_separator: {subset.keep_tokens_separator}
+          caption_dropout_rate: {subset.caption_dropout_rate}
+          caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
+          caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+          caption_prefix: {subset.caption_prefix}
+          caption_suffix: {subset.caption_suffix}
+          color_aug: {subset.color_aug}
+          flip_aug: {subset.flip_aug}
+          face_crop_aug_range: {subset.face_crop_aug_range}
+          random_crop: {subset.random_crop}
+          token_warmup_min: {subset.token_warmup_min}
+          token_warmup_step: {subset.token_warmup_step}
+      """), "  ")
+
+      if is_dreambooth:
+        info += indent(dedent(f"""\
+          is_reg: {subset.is_reg}
+          class_tokens: {subset.class_tokens}
+          caption_extension: {subset.caption_extension}
+        \n"""), "    ")
+      elif not is_controlnet:
+        info += indent(dedent(f"""\
+          metadata_file: {subset.metadata_file}
+        \n"""), "    ")
+
+  print(info)
+
+  # make buckets first because it determines the length of the dataset,
+  # and set the same seed for all datasets
+  seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
+  for i, dataset in enumerate(datasets):
+    print(f"[Dataset {i}]")
+    dataset.make_buckets()
+    dataset.set_seed(seed)
+
+  return DatasetGroup(datasets)
+
+
+def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
+  def extract_dreambooth_params(name: str) -> Tuple[int, str]:
+    tokens = name.split('_')
+    try:
+      n_repeats = int(tokens[0])
+    except ValueError as e:
+      print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
+      return 0, ""
+    caption_by_folder = '_'.join(tokens[1:])
+    return n_repeats, caption_by_folder
+
+  def generate(base_dir: Optional[str], is_reg: bool):
+    if base_dir is None:
+      return []
+
+    base_dir: Path = Path(base_dir)
+    if not base_dir.is_dir():
+      return []
+
+    subsets_config = []
+    for subdir in base_dir.iterdir():
+      if not subdir.is_dir():
+        continue
+
+      num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
+      if num_repeats < 1:
+        continue
+
+      subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
+      subsets_config.append(subset_config)
+
+    return subsets_config
+
+  subsets_config = []
+  subsets_config += generate(train_data_dir, False)
+  subsets_config += generate(reg_data_dir, True)
+
+  return subsets_config
+
+
+def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
+  def generate(base_dir: Optional[str]):
+    if base_dir is None:
+      return []
+
+    base_dir: Path = Path(base_dir)
+    if not base_dir.is_dir():
+      return []
+
+    subsets_config = []
+    subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
+    subsets_config.append(subset_config)
+
+    return subsets_config
+
+  subsets_config = []
+  subsets_config += generate(train_data_dir)
+
+  return subsets_config
+
+
+def load_user_config(file: str) -> dict:
+  file: Path = Path(file)
+  if not file.is_file():
+    raise ValueError(f"file not found / ファイルが見つかりません: {file}")
+
+  if file.name.lower().endswith('.json'):
+    try:
+      with open(file, 'r') as f:
+        config = json.load(f)
+    except Exception:
+      print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+      raise
+  elif file.name.lower().endswith('.toml'):
+    try:
+      config = toml.load(file)
+    except Exception:
+      print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+      raise
+  else:
+    raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
+
+  return config
+
+# for config test
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--support_dreambooth", action="store_true")
+  parser.add_argument("--support_finetuning", action="store_true")
+  parser.add_argument("--support_controlnet", action="store_true")
+  parser.add_argument("--support_dropout", action="store_true")
+  parser.add_argument("dataset_config")
+  config_args, remain = parser.parse_known_args()
+
+  parser = argparse.ArgumentParser()
+  train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
+  train_util.add_training_arguments(parser, config_args.support_dreambooth)
+  argparse_namespace = parser.parse_args(remain)
+  train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
+
+  print("[argparse_namespace]")
+  print(vars(argparse_namespace))
+
+  user_config = load_user_config(config_args.dataset_config)
+
+  print("\n[user_config]")
+  print(user_config)
+
+  sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
+  sanitized_user_config = sanitizer.sanitize_user_config(user_config)
+
+  print("\n[sanitized_user_config]")
+  print(sanitized_user_config)
+
+  blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
+
+  print("\n[blueprint]")
+  print(blueprint)
diff --git a/external/llite/library/custom_train_functions.py b/external/llite/library/custom_train_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..629e1a2ebe0a0df63b637217fbf80d6f558a89db
--- /dev/null
+++ b/external/llite/library/custom_train_functions.py
@@ -0,0 +1,529 @@
+import torch
+import argparse
+import random
+import re
+from typing import List, Optional, Union
+
+
+def prepare_scheduler_for_custom_training(noise_scheduler, device):
+    if hasattr(noise_scheduler, "all_snr"):
+        return
+
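+    # per-timestep SNR: alphas_cumprod / (1 - alphas_cumprod), cached on the scheduler
+    # for the loss-weighting helpers below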
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+    alpha = sqrt_alphas_cumprod
+    sigma = sqrt_one_minus_alphas_cumprod
+    all_snr = (alpha / sigma) ** 2
+
+    noise_scheduler.all_snr = all_snr.to(device)
+
+
+def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
+    # fix beta: zero terminal SNR
+    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
+
+    def enforce_zero_terminal_snr(betas):
+        # Convert betas to alphas_bar_sqrt
+        alphas = 1 - betas
+        alphas_bar = alphas.cumprod(0)
+        alphas_bar_sqrt = alphas_bar.sqrt()
+
+        # Store old values.
+        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+        # Shift so last timestep is zero.
+        alphas_bar_sqrt -= alphas_bar_sqrt_T
+        # Scale so first timestep is back to old value.
+        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+        # Convert alphas_bar_sqrt to betas
+        alphas_bar = alphas_bar_sqrt**2
+        alphas = alphas_bar[1:] / alphas_bar[:-1]
+        alphas = torch.cat([alphas_bar[0:1], alphas])
+        betas = 1 - alphas
+        return betas
+
+    betas = noise_scheduler.betas
+    betas = enforce_zero_terminal_snr(betas)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+    # print("original:", noise_scheduler.betas)
+    # print("fixed:", betas)
+
+    noise_scheduler.betas = betas
+    noise_scheduler.alphas = alphas
+    noise_scheduler.alphas_cumprod = alphas_cumprod
+
+
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
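+    # Min-SNR weighting: clamp the per-timestep SNR at gamma, then weight the loss by
+    # min(SNR, gamma) / SNR (or / (SNR + 1) for v-prediction)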
+    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
+    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
+    if v_prediction:
+        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
+    else:
+        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
+    loss = loss * snr_weight
+    return loss
+
+
+def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
+    scale = get_snr_scale(timesteps, noise_scheduler)
+    loss = loss * scale
+    return loss
+
+
+def get_snr_scale(timesteps, noise_scheduler):
+    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
+    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
+    scale = snr_t / (snr_t + 1)
+    # # show debug info
+    # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
+    return scale
+
+
+def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
+    scale = get_snr_scale(timesteps, noise_scheduler)
+    # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
+    loss = loss + loss / scale * v_pred_like_loss
+    return loss
+
+def apply_debiased_estimation(loss, timesteps, noise_scheduler):
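+    # debiased estimation: down-weight low-noise timesteps by 1 / sqrt(SNR),
+    # with SNR clamped at 1000 because it is infinite at timestep 0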
+    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
+    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
+    weight = 1/torch.sqrt(snr_t)
+    loss = weight * loss
+    return loss
+
+# TODO: this is split between here and train_util; consolidate it into one of them
+
+
+def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
+    parser.add_argument(
+        "--min_snr_gamma",
+        type=float,
+        default=None,
+        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
+    )
+    parser.add_argument(
+        "--scale_v_pred_loss_like_noise_pred",
+        action="store_true",
+        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
+    )
+    parser.add_argument(
+        "--v_pred_like_loss",
+        type=float,
+        default=None,
+        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
+    )
+    parser.add_argument(
+        "--debiased_estimation_loss",
+        action="store_true",
+        help="debiased estimation loss / debiased estimation loss",
+    )
+    if support_weighted_captions:
+        parser.add_argument(
+            "--weighted_captions",
+            action="store_true",
+            default=False,
+            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
+        )
+
+
+re_attention = re.compile(
+    r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+    re.X,
+)
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+    Accepted tokens are:
+      (abc) - increases attention to abc by a multiplier of 1.1
+      (abc:3.12) - increases attention to abc by a multiplier of 3.12
+      [abc] - decreases attention to abc by a multiplier of 1.1
+      \( - literal character '('
+      \[ - literal character '['
+      \) - literal character ')'
+      \] - literal character ']'
+      \\ - literal character '\'
+      anything else - just text
+    >>> parse_prompt_attention('normal text')
+    [['normal text', 1.0]]
+    >>> parse_prompt_attention('an (important) word')
+    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+    >>> parse_prompt_attention('(unbalanced')
+    [['unbalanced', 1.1]]
+    >>> parse_prompt_attention('\(literal\]')
+    [['(literal]', 1.0]]
+    >>> parse_prompt_attention('(unnecessary)(parens)')
+    [['unnecessaryparens', 1.1]]
+    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+    [['a ', 1.0],
+     ['house', 1.5730000000000004],
+     [' ', 1.1],
+     ['on', 1.0],
+     [' a ', 1.1],
+     ['hill', 0.55],
+     [', sun, ', 1.1],
+     ['sky', 1.4641000000000006],
+     ['.', 1.1]]
+    """
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith("\\"):
+            res.append([text[1:], 1.0])
+        elif text == "(":
+            round_brackets.append(len(res))
+        elif text == "[":
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ")" and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == "]" and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            res.append([text, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    if len(res) == 0:
+        res = [["", 1.0]]
+
+    # merge runs of identical weights
+    i = 0
+    while i + 1 < len(res):
+        if res[i][1] == res[i + 1][1]:
+            res[i][0] += res[i + 1][0]
+            res.pop(i + 1)
+        else:
+            i += 1
+
+    return res
+
+
+def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
+    r"""
+    Tokenize a list of prompts and return its tokens with weights of each token.
+
+    No padding, starting or ending token is included.
+    """
+    tokens = []
+    weights = []
+    truncated = False
+    for text in prompt:
+        texts_and_weights = parse_prompt_attention(text)
+        text_token = []
+        text_weight = []
+        for word, weight in texts_and_weights:
+            # tokenize and discard the starting and the ending token
+            token = tokenizer(word).input_ids[1:-1]
+            text_token += token
+            # copy the weight by length of token
+            text_weight += [weight] * len(token)
+            # stop if the text is too long (longer than truncation limit)
+            if len(text_token) > max_length:
+                truncated = True
+                break
+        # truncate
+        if len(text_token) > max_length:
+            truncated = True
+            text_token = text_token[:max_length]
+            text_weight = text_weight[:max_length]
+        tokens.append(text_token)
+        weights.append(text_weight)
+    if truncated:
+        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+    return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+    r"""
+    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+    """
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+    for i in range(len(tokens)):
+        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        if no_boseos_middle:
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+        else:
+            w = []
+            if len(weights[i]) == 0:
+                w = [1.0] * weights_length
+            else:
+                for j in range(max_embeddings_multiples):
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
+            weights[i] = w[:]
+
+    return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+    tokenizer,
+    text_encoder,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    clip_skip: int,
+    eos: int,
+    pad: int,
+    no_boseos_middle: Optional[bool] = True,
+):
+    """
+    When the length of tokens exceeds the capacity of the text encoder, the input
+    is split into chunks and each chunk is sent to the text encoder individually.
+    """
+    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+    if max_embeddings_multiples > 1:
+        text_embeddings = []
+        for i in range(max_embeddings_multiples):
+            # extract the i-th chunk
+            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+            # cover the head and the tail with the starting and ending tokens
+            text_input_chunk[:, 0] = text_input[0, 0]
+            if pad == eos:  # v1
+                text_input_chunk[:, -1] = text_input[0, -1]
+            else:  # v2
+                for j in range(len(text_input_chunk)):
+                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with a normal token
+                        text_input_chunk[j, -1] = eos
+                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
+                        text_input_chunk[j, 1] = eos
+
+            if clip_skip is None or clip_skip == 1:
+                text_embedding = text_encoder(text_input_chunk)[0]
+            else:
+                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
+                text_embedding = enc_out["hidden_states"][-clip_skip]
+                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
+
+            if no_boseos_middle:
+                if i == 0:
+                    # discard the ending token
+                    text_embedding = text_embedding[:, :-1]
+                elif i == max_embeddings_multiples - 1:
+                    # discard the starting token
+                    text_embedding = text_embedding[:, 1:]
+                else:
+                    # discard both starting and ending tokens
+                    text_embedding = text_embedding[:, 1:-1]
+
+            text_embeddings.append(text_embedding)
+        text_embeddings = torch.concat(text_embeddings, axis=1)
+    else:
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = text_encoder(text_input)[0]
+        else:
+            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
+    return text_embeddings
+
+
+def get_weighted_text_embeddings(
+    tokenizer,
+    text_encoder,
+    prompt: Union[str, List[str]],
+    device,
+    max_embeddings_multiples: Optional[int] = 3,
+    no_boseos_middle: Optional[bool] = False,
+    clip_skip=None,
+):
+    r"""
+    Prompts can be assigned with local weights using brackets. For example,
+    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
+
+    Args:
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to guide the image generation.
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+            The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
+            starting and ending tokens of each chunk in the middle.
+    """
+    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
+
+    # round up the longest length of tokens to a multiple of (model_max_length - 2)
+    max_length = max([len(token) for token in prompt_tokens])
+
+    max_embeddings_multiples = min(
+        max_embeddings_multiples,
+        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
+    )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
+    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+    # pad the length of tokens and weights
+    bos = tokenizer.bos_token_id
+    eos = tokenizer.eos_token_id
+    pad = tokenizer.pad_token_id
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=tokenizer.model_max_length,
+    )
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
+
+    # get the embeddings
+    text_embeddings = get_unweighted_text_embeddings(
+        tokenizer,
+        text_encoder,
+        prompt_tokens,
+        tokenizer.model_max_length,
+        clip_skip,
+        eos,
+        pad,
+        no_boseos_middle=no_boseos_middle,
+    )
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
+
+    # assign weights to the prompts and normalize in the sense of mean
+    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
+    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+    return text_embeddings
+
+
+# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
+def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
+    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
+    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
+    for i in range(iterations):
+        r = random.random() * 2 + 2  # Rather than always going 2x,
+        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
+        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
+        if wn == 1 or hn == 1:
+            break  # Lowest resolution is 1x1
+    return noise / noise.std()  # Scaled back to roughly unit variance
+
+
+# https://www.crosslabs.org//blog/diffusion-with-offset-noise
+def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
+    if noise_offset is None:
+        return noise
+    if adaptive_noise_scale is not None:
+        # latent shape: (batch_size, channels, height, width)
+        # abs mean value for each channel
+        latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
+
+        # multiply adaptive noise scale to the mean value and add it to the noise offset
+        noise_offset = noise_offset + adaptive_noise_scale * latent_mean
+        noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case the adaptive noise scale is negative
+
+    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+    return noise
+
+
+"""
+##########################################
+# Perlin Noise
+def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = (
+        torch.stack(
+            torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
+            dim=-1,
+        )
+        % 1
+    )
+    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    tile_grads = (
+        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
+        .repeat_interleave(d[0], 0)
+        .repeat_interleave(d[1], 1)
+    )
+    dot = lambda grad, shift: (
+        torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
+        * grad[: shape[0], : shape[1]]
+    ).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[: shape[0], : shape[1]])
+    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
+    noise = torch.zeros(shape, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(octaves):
+        noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(noise, device, octaves):
+    _, c, w, h = noise.shape
+    perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
+    noise_perlin = []
+    for _ in range(c):
+        noise_perlin.append(perlin())
+    noise_perlin = torch.stack(noise_perlin).unsqueeze(0)   # (1, c, w, h)
+    noise += noise_perlin # broadcast for each batch
+    return noise / noise.std()  # Scaled back to roughly unit variance
+"""
diff --git a/external/llite/library/huggingface_util.py b/external/llite/library/huggingface_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef65af9ce5adc691019e0bae0e0d8f47a1d7eeb2
--- /dev/null
+++ b/external/llite/library/huggingface_util.py
@@ -0,0 +1,81 @@
+from typing import Union, BinaryIO
+from huggingface_hub import HfApi
+from pathlib import Path
+import argparse
+import os
+from external.llite.library.utils import fire_in_thread
+
+
+def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
+    api = HfApi(
+        token=token,
+    )
+    try:
+        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+        return True
+    except Exception:
+        return False
+
+
+def upload(
+    args: argparse.Namespace,
+    src: Union[str, Path, bytes, BinaryIO],
+    dest_suffix: str = "",
+    force_sync_upload: bool = False,
+):
+    repo_id = args.huggingface_repo_id
+    repo_type = args.huggingface_repo_type
+    token = args.huggingface_token
+    path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
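+    # the repo is created as private unless the visibility is explicitly "public"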
+    private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
+    api = HfApi(token=token)
+    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
+        try:
+            api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
+        except Exception as e:  # RepositoryNotFoundError is the confirmed case, but catch broadly in case of others
+            print("===========================================")
+            print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
+            print("===========================================")
+
+    is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
+
+    def uploader():
+        try:
+            if is_folder:
+                api.upload_folder(
+                    repo_id=repo_id,
+                    repo_type=repo_type,
+                    folder_path=src,
+                    path_in_repo=path_in_repo,
+                )
+            else:
+                api.upload_file(
+                    repo_id=repo_id,
+                    repo_type=repo_type,
+                    path_or_fileobj=src,
+                    path_in_repo=path_in_repo,
+                )
+        except Exception as e:  # RuntimeError is the confirmed case, but catch broadly in case of others
+            print("===========================================")
+            print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
+            print("===========================================")
+
+    if args.async_upload and not force_sync_upload:
+        fire_in_thread(uploader)
+    else:
+        uploader()
+
+
+def list_dir(
+    repo_id: str,
+    subfolder: str,
+    repo_type: str,
+    revision: str = "main",
+    token: str = None,
+):
+    api = HfApi(
+        token=token,
+    )
+    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+    file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
+    return file_list
diff --git a/external/llite/library/hypernetwork.py b/external/llite/library/hypernetwork.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbd3fb24e1a5bc314b282407d1c6282a197d96a3
--- /dev/null
+++ b/external/llite/library/hypernetwork.py
@@ -0,0 +1,223 @@
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor2_0,
+    SlicedAttnProcessor,
+    XFormersAttnProcessor
+)
+
+try:
+    import xformers.ops
+except ImportError:
+    xformers = None
+
+
+loaded_networks = []
+
+
+def apply_single_hypernetwork(
+    hypernetwork, hidden_states, encoder_hidden_states
+):
+    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
+    return context_k, context_v
+
+
+def apply_hypernetworks(context_k, context_v, layer=None):
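+    # callers pass (hidden_states, encoder_hidden_states); when no hypernetwork is
+    # loaded, both the K and V contexts fall back to encoder_hidden_states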
+    if len(loaded_networks) == 0:
+        return context_v, context_v
+    for hypernetwork in loaded_networks:
+        context_k, context_v = hypernetwork.forward(context_k, context_v)
+
+    context_k = context_k.to(dtype=context_k.dtype)
+    context_v = context_v.to(dtype=context_k.dtype)
+
+    return context_k, context_v
+
+
+
+def xformers_forward(
+    self: XFormersAttnProcessor,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
+):
+    batch_size, sequence_length, _ = (
+        hidden_states.shape
+        if encoder_hidden_states is None
+        else encoder_hidden_states.shape
+    )
+
+    attention_mask = attn.prepare_attention_mask(
+        attention_mask, sequence_length, batch_size
+    )
+
+    query = attn.to_q(hidden_states)
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+    elif attn.norm_cross:
+        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
+
+    key = attn.to_k(context_k)
+    value = attn.to_v(context_v)
+
+    query = attn.head_to_batch_dim(query).contiguous()
+    key = attn.head_to_batch_dim(key).contiguous()
+    value = attn.head_to_batch_dim(value).contiguous()
+
+    hidden_states = xformers.ops.memory_efficient_attention(
+        query,
+        key,
+        value,
+        attn_bias=attention_mask,
+        op=self.attention_op,
+        scale=attn.scale,
+    )
+    hidden_states = hidden_states.to(query.dtype)
+    hidden_states = attn.batch_to_head_dim(hidden_states)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+    return hidden_states
+
+
+def sliced_attn_forward(
+    self: SlicedAttnProcessor,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
+):
+    batch_size, sequence_length, _ = (
+        hidden_states.shape
+        if encoder_hidden_states is None
+        else encoder_hidden_states.shape
+    )
+    attention_mask = attn.prepare_attention_mask(
+        attention_mask, sequence_length, batch_size
+    )
+
+    query = attn.to_q(hidden_states)
+    dim = query.shape[-1]
+    query = attn.head_to_batch_dim(query)
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+    elif attn.norm_cross:
+        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
+
+    key = attn.to_k(context_k)
+    value = attn.to_v(context_v)
+    key = attn.head_to_batch_dim(key)
+    value = attn.head_to_batch_dim(value)
+
+    batch_size_attention, query_tokens, _ = query.shape
+    hidden_states = torch.zeros(
+        (batch_size_attention, query_tokens, dim // attn.heads),
+        device=query.device,
+        dtype=query.dtype,
+    )
+
+    for i in range(batch_size_attention // self.slice_size):
+        start_idx = i * self.slice_size
+        end_idx = (i + 1) * self.slice_size
+
+        query_slice = query[start_idx:end_idx]
+        key_slice = key[start_idx:end_idx]
+        attn_mask_slice = (
+            attention_mask[start_idx:end_idx] if attention_mask is not None else None
+        )
+
+        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+        hidden_states[start_idx:end_idx] = attn_slice
+
+    hidden_states = attn.batch_to_head_dim(hidden_states)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+
+    return hidden_states
+
+
+def v2_0_forward(
+    self: AttnProcessor2_0,
+    attn: Attention,
+    hidden_states,
+    encoder_hidden_states=None,
+    attention_mask=None,
+):
+    batch_size, sequence_length, _ = (
+        hidden_states.shape
+        if encoder_hidden_states is None
+        else encoder_hidden_states.shape
+    )
+    inner_dim = hidden_states.shape[-1]
+
+    if attention_mask is not None:
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+        # scaled_dot_product_attention expects attention_mask shape to be
+        # (batch, heads, source_length, target_length)
+        attention_mask = attention_mask.view(
+            batch_size, attn.heads, -1, attention_mask.shape[-1]
+        )
+
+    query = attn.to_q(hidden_states)
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+    elif attn.norm_cross:
+        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
+
+    key = attn.to_k(context_k)
+    value = attn.to_v(context_v)
+
+    head_dim = inner_dim // attn.heads
+    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+    # TODO: add support for attn.scale when we move to Torch 2.1
+    hidden_states = F.scaled_dot_product_attention(
+        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+    )
+
+    hidden_states = hidden_states.transpose(1, 2).reshape(
+        batch_size, -1, attn.heads * head_dim
+    )
+    hidden_states = hidden_states.to(query.dtype)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+    return hidden_states
+
+
+def replace_attentions_for_hypernetwork():
+    import diffusers.models.attention_processor
+
+    diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
+        xformers_forward
+    )
+    diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
+        sliced_attn_forward
+    )
+    diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
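+
+
+# Minimal usage sketch (hedged): `my_hypernetwork` is a hypothetical object exposing a
+# `forward(context_k, context_v)` method; this module only consumes such objects and does not
+# define them, and the import path assumes the layout shown in this diff.
+#
+#   from external.llite.library import hypernetwork
+#
+#   hypernetwork.loaded_networks.append(my_hypernetwork)
+#   hypernetwork.replace_attentions_for_hypernetwork()
+#   # from here on, the diffusers XFormers / Sliced / SDPA attention processors route their
+#   # key/value context through apply_hypernetworks()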
diff --git a/external/llite/library/ipex/__init__.py b/external/llite/library/ipex/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78547915bdf162e7fef3562acfb9c2add18be23
--- /dev/null
+++ b/external/llite/library/ipex/__init__.py
@@ -0,0 +1,169 @@
+import os
+import sys
+import contextlib
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+from .hijacks import ipex_hijacks
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long
+
+def ipex_init(): # pylint: disable=too-many-statements
+    try:
+        # Replace cuda with xpu:
+        torch.cuda.current_device = torch.xpu.current_device
+        torch.cuda.current_stream = torch.xpu.current_stream
+        torch.cuda.device = torch.xpu.device
+        torch.cuda.device_count = torch.xpu.device_count
+        torch.cuda.device_of = torch.xpu.device_of
+        torch.cuda.get_device_name = torch.xpu.get_device_name
+        torch.cuda.get_device_properties = torch.xpu.get_device_properties
+        torch.cuda.init = torch.xpu.init
+        torch.cuda.is_available = torch.xpu.is_available
+        torch.cuda.is_initialized = torch.xpu.is_initialized
+        torch.cuda.is_current_stream_capturing = lambda: False
+        torch.cuda.set_device = torch.xpu.set_device
+        torch.cuda.stream = torch.xpu.stream
+        torch.cuda.synchronize = torch.xpu.synchronize
+        torch.cuda.Event = torch.xpu.Event
+        torch.cuda.Stream = torch.xpu.Stream
+        torch.cuda.FloatTensor = torch.xpu.FloatTensor
+        torch.Tensor.cuda = torch.Tensor.xpu
+        torch.Tensor.is_cuda = torch.Tensor.is_xpu
+        torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
+        torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
+        torch.cuda._initialized = torch.xpu.lazy_init._initialized
+        torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
+        torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
+        torch.cuda._tls = torch.xpu.lazy_init._tls
+        torch.cuda.threading = torch.xpu.lazy_init.threading
+        torch.cuda.traceback = torch.xpu.lazy_init.traceback
+        torch.cuda.Optional = torch.xpu.Optional
+        torch.cuda.__cached__ = torch.xpu.__cached__
+        torch.cuda.__loader__ = torch.xpu.__loader__
+        torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
+        torch.cuda.Tuple = torch.xpu.Tuple
+        torch.cuda.streams = torch.xpu.streams
+        torch.cuda._lazy_new = torch.xpu._lazy_new
+        torch.cuda.FloatStorage = torch.xpu.FloatStorage
+        torch.cuda.Any = torch.xpu.Any
+        torch.cuda.__doc__ = torch.xpu.__doc__
+        torch.cuda.default_generators = torch.xpu.default_generators
+        torch.cuda.HalfTensor = torch.xpu.HalfTensor
+        torch.cuda._get_device_index = torch.xpu._get_device_index
+        torch.cuda.__path__ = torch.xpu.__path__
+        torch.cuda.Device = torch.xpu.Device
+        torch.cuda.IntTensor = torch.xpu.IntTensor
+        torch.cuda.ByteStorage = torch.xpu.ByteStorage
+        torch.cuda.set_stream = torch.xpu.set_stream
+        torch.cuda.BoolStorage = torch.xpu.BoolStorage
+        torch.cuda.os = torch.xpu.os
+        torch.cuda.torch = torch.xpu.torch
+        torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
+        torch.cuda.Union = torch.xpu.Union
+        torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
+        torch.cuda.ShortTensor = torch.xpu.ShortTensor
+        torch.cuda.LongTensor = torch.xpu.LongTensor
+        torch.cuda.IntStorage = torch.xpu.IntStorage
+        torch.cuda.LongStorage = torch.xpu.LongStorage
+        torch.cuda.__annotations__ = torch.xpu.__annotations__
+        torch.cuda.__package__ = torch.xpu.__package__
+        torch.cuda.__builtins__ = torch.xpu.__builtins__
+        torch.cuda.CharTensor = torch.xpu.CharTensor
+        torch.cuda.List = torch.xpu.List
+        torch.cuda._lazy_init = torch.xpu._lazy_init
+        torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
+        torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
+        torch.cuda.ByteTensor = torch.xpu.ByteTensor
+        torch.cuda.StreamContext = torch.xpu.StreamContext
+        torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
+        torch.cuda.ShortStorage = torch.xpu.ShortStorage
+        torch.cuda._lazy_call = torch.xpu._lazy_call
+        torch.cuda.HalfStorage = torch.xpu.HalfStorage
+        torch.cuda.random = torch.xpu.random
+        torch.cuda._device = torch.xpu._device
+        torch.cuda.classproperty = torch.xpu.classproperty
+        torch.cuda.__name__ = torch.xpu.__name__
+        torch.cuda._device_t = torch.xpu._device_t
+        torch.cuda.warnings = torch.xpu.warnings
+        torch.cuda.__spec__ = torch.xpu.__spec__
+        torch.cuda.BoolTensor = torch.xpu.BoolTensor
+        torch.cuda.CharStorage = torch.xpu.CharStorage
+        torch.cuda.__file__ = torch.xpu.__file__
+        torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
+        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
+
+        # Memory:
+        torch.cuda.memory = torch.xpu.memory
+        if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
+            torch.xpu.empty_cache = lambda: None
+        torch.cuda.empty_cache = torch.xpu.empty_cache
+        torch.cuda.memory_stats = torch.xpu.memory_stats
+        torch.cuda.memory_summary = torch.xpu.memory_summary
+        torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
+        torch.cuda.memory_allocated = torch.xpu.memory_allocated
+        torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
+        torch.cuda.memory_reserved = torch.xpu.memory_reserved
+        torch.cuda.memory_cached = torch.xpu.memory_reserved
+        torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
+        torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
+        torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
+        torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
+        torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
+        torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
+        torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
+
+        # RNG:
+        torch.cuda.get_rng_state = torch.xpu.get_rng_state
+        torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
+        torch.cuda.set_rng_state = torch.xpu.set_rng_state
+        torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
+        torch.cuda.manual_seed = torch.xpu.manual_seed
+        torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
+        torch.cuda.seed = torch.xpu.seed
+        torch.cuda.seed_all = torch.xpu.seed_all
+        torch.cuda.initial_seed = torch.xpu.initial_seed
+
+        # AMP:
+        torch.cuda.amp = torch.xpu.amp
+        if not hasattr(torch.cuda.amp, "common"):
+            torch.cuda.amp.common = contextlib.nullcontext()
+        torch.cuda.amp.common.amp_definitely_not_available = lambda: False
+        try:
+            torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
+        except Exception: # pylint: disable=broad-exception-caught
+            try:
+                from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
+                gradscaler_init()
+                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
+            except Exception: # pylint: disable=broad-exception-caught
+                torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
+
+        # C
+        torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
+        ipex._C._DeviceProperties.major = 2023
+        ipex._C._DeviceProperties.minor = 2
+
+        # Fix functions with ipex:
+        torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
+        torch._utils._get_available_device_type = lambda: "xpu"
+        torch.has_cuda = True
+        torch.cuda.has_half = True
+        torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
+        torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
+        torch.version.cuda = "11.7"
+        torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
+        torch.cuda.get_device_properties.major = 11
+        torch.cuda.get_device_properties.minor = 7
+        torch.cuda.ipc_collect = lambda *args, **kwargs: None
+        torch.cuda.utilization = lambda *args, **kwargs: 0
+
+        ipex_hijacks()
+        if not torch.xpu.has_fp64_dtype():
+            try:
+                from .diffusers import ipex_diffusers
+                ipex_diffusers()
+            except Exception: # pylint: disable=broad-exception-caught
+                pass
+    except Exception as e:
+        return False, e
+    return True, None
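+
+
+# Minimal usage sketch (hedged): call ipex_init() once, before anything touches torch.cuda.*,
+# and inspect the (success, error) tuple it returns. The import path assumes the layout shown
+# in this diff.
+#
+#   from external.llite.library.ipex import ipex_init
+#
+#   ok, err = ipex_init()
+#   if not ok:
+#       print(f"failed to alias torch.cuda to torch.xpu: {err}")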
diff --git a/external/llite/library/ipex/attention.py b/external/llite/library/ipex/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e61f2c9068b633fba76e25d35c125a33a22099a
--- /dev/null
+++ b/external/llite/library/ipex/attention.py
@@ -0,0 +1,151 @@
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long
+
+original_torch_bmm = torch.bmm
+def torch_bmm_32_bit(input, mat2, *, out=None):
+    # ARC GPUs can't allocate more than 4 GB to a single block; slice it:
+    batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
+    block_multiply = input.element_size()
+    slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
+    block_size = batch_size_attention * slice_block_size
+
+    split_slice_size = batch_size_attention
+    if block_size > 4:
+        do_split = True
+        # Find something divisible with the batch_size_attention
+        while (split_slice_size * slice_block_size) > 4:
+            split_slice_size = split_slice_size // 2
+            if split_slice_size <= 1:
+                split_slice_size = 1
+                break
+        split_2_slice_size = input_tokens
+        if split_slice_size * slice_block_size > 4:
+            slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
+            do_split_2 = True
+            # Find something divisible with the input_tokens
+            while (split_2_slice_size * slice_block_size_2) > 4:
+                split_2_slice_size = split_2_slice_size // 2
+                if split_2_slice_size <= 1:
+                    split_2_slice_size = 1
+                    break
+        else:
+            do_split_2 = False
+    else:
+        do_split = False
+
+    if do_split:
+        hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
+        for i in range(batch_size_attention // split_slice_size):
+            start_idx = i * split_slice_size
+            end_idx = (i + 1) * split_slice_size
+            if do_split_2:
+                for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                    start_idx_2 = i2 * split_2_slice_size
+                    end_idx_2 = (i2 + 1) * split_2_slice_size
+                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
+                        input[start_idx:end_idx, start_idx_2:end_idx_2],
+                        mat2[start_idx:end_idx],  # mat2 must keep its full contraction dimension for bmm
+                        out=out
+                    )
+            else:
+                hidden_states[start_idx:end_idx] = original_torch_bmm(
+                    input[start_idx:end_idx],
+                    mat2[start_idx:end_idx],
+                    out=out
+                )
+    else:
+        return original_torch_bmm(input, mat2, out=out)
+    return hidden_states
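+
+# Worked example of the slicing heuristic above (hypothetical fp16 shapes): for
+# input (16, 4096, 64) @ mat2 (16, 64, 4096), batch=16, tokens=4096, mat2_shape=4096, so
+# slice_block_size = 4096 * 4096 / 1024 / 1024 * 2 = 32 and block_size = 16 * 32 = 512 > 4.
+# The batch split halves 16 down to 1; since 1 * 32 is still > 4, the token split halves
+# 4096 down to 512 (512 * 1 * 4096 / 1024 / 1024 * 2 = 4), so original_torch_bmm is called
+# on (1, 512, 64) x (1, 64, 4096) tiles.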
+
+original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    # ARC GPUs can't allocate more than 4 GB to a single block; slice it:
+    if len(query.shape) == 3:
+        batch_size_attention, query_tokens, shape_three = query.shape
+        shape_four = 1
+    else:
+        batch_size_attention, query_tokens, shape_three, shape_four = query.shape
+
+    block_multiply = query.element_size()
+    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
+    block_size = batch_size_attention * slice_block_size
+
+    split_slice_size = batch_size_attention
+    if block_size > 4:
+        do_split = True
+        # Find something divisible with the batch_size_attention
+        while (split_slice_size * slice_block_size) > 4:
+            split_slice_size = split_slice_size // 2
+            if split_slice_size <= 1:
+                split_slice_size = 1
+                break
+        split_2_slice_size = query_tokens
+        if split_slice_size * slice_block_size > 4:
+            slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
+            do_split_2 = True
+            # Find something divisible with the query_tokens
+            while (split_2_slice_size * slice_block_size_2) > 4:
+                split_2_slice_size = split_2_slice_size // 2
+                if split_2_slice_size <= 1:
+                    split_2_slice_size = 1
+                    break
+            split_3_slice_size = shape_three
+            if split_2_slice_size * slice_block_size_2 > 4:
+                slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
+                do_split_3 = True
+                # Find something divisible with the shape_three
+                while (split_3_slice_size * slice_block_size_3) > 4:
+                    split_3_slice_size = split_3_slice_size // 2
+                    if split_3_slice_size <= 1:
+                        split_3_slice_size = 1
+                        break
+            else:
+                do_split_3 = False
+        else:
+            do_split_2 = False
+    else:
+        do_split = False
+
+    if do_split:
+        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
+        for i in range(batch_size_attention // split_slice_size):
+            start_idx = i * split_slice_size
+            end_idx = (i + 1) * split_slice_size
+            if do_split_2:
+                for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                    start_idx_2 = i2 * split_2_slice_size
+                    end_idx_2 = (i2 + 1) * split_2_slice_size
+                    if do_split_3:
+                        for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
+                            start_idx_3 = i3 * split_3_slice_size
+                            end_idx_3 = (i3 + 1) * split_3_slice_size
+                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
+                                query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
+                                key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
+                                value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
+                                attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
+                                dropout_p=dropout_p, is_causal=is_causal
+                            )
+                    else:
+                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                            query[start_idx:end_idx, start_idx_2:end_idx_2],
+                            key[start_idx:end_idx, start_idx_2:end_idx_2],
+                            value[start_idx:end_idx, start_idx_2:end_idx_2],
+                            attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
+                            dropout_p=dropout_p, is_causal=is_causal
+                        )
+            else:
+                hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
+                    query[start_idx:end_idx],
+                    key[start_idx:end_idx],
+                    value[start_idx:end_idx],
+                    attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
+                    dropout_p=dropout_p, is_causal=is_causal
+                )
+    else:
+        return original_scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+        )
+    return hidden_states
diff --git a/external/llite/library/ipex/diffusers.py b/external/llite/library/ipex/diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32af507b3e8976b8f1a10d4d28f9943cb33d4b4
--- /dev/null
+++ b/external/llite/library/ipex/diffusers.py
@@ -0,0 +1,120 @@
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import diffusers #0.24.0 # pylint: disable=import-error
+from diffusers.models.attention_processor import Attention
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long
+
+class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
+    r"""
+    Processor for implementing sliced attention.
+
+    Args:
+        slice_size (`int`, *optional*):
+            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+            `attention_head_dim` must be a multiple of the `slice_size`.
+    """
+
+    def __init__(self, slice_size):
+        self.slice_size = slice_size
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = attn.head_to_batch_dim(query)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        batch_size_attention, query_tokens, shape_three = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        # ARC GPUs can't allocate more than 4 GB to a single block; slice it:
+        block_multiply = query.element_size()
+        slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
+        block_size = query_tokens * slice_block_size
+        split_2_slice_size = query_tokens
+        if block_size > 4:
+            do_split_2 = True
+            # Find something divisible with the query_tokens
+            while (split_2_slice_size * slice_block_size) > 4:
+                split_2_slice_size = split_2_slice_size // 2
+                if split_2_slice_size <= 1:
+                    split_2_slice_size = 1
+                    break
+        else:
+            do_split_2 = False
+
+        for i in range(batch_size_attention // self.slice_size):
+            start_idx = i * self.slice_size
+            end_idx = (i + 1) * self.slice_size
+
+            if do_split_2:
+                for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                    start_idx_2 = i2 * split_2_slice_size
+                    end_idx_2 = (i2 + 1) * split_2_slice_size
+
+                    query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
+                    key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
+                    attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
+
+                    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+                    attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
+
+                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
+            else:
+                query_slice = query[start_idx:end_idx]
+                key_slice = key[start_idx:end_idx]
+                attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+                hidden_states[start_idx:end_idx] = attn_slice
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+def ipex_diffusers():
+    # ARC GPUs can't allocate more than 4 GB to a single block:
+    diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
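+
+# Usage note (hedged): ipex_init() calls ipex_diffusers() on GPUs without FP64 support, so
+# pipelines that enable attention slicing (e.g. `pipe.enable_attention_slicing()`) should pick
+# up this tiled SlicedAttnProcessor in place of the stock diffusers implementation.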
diff --git a/external/llite/library/ipex/gradscaler.py b/external/llite/library/ipex/gradscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eb56bc2b821e8530557f517ebeaafa141b763a6
--- /dev/null
+++ b/external/llite/library/ipex/gradscaler.py
@@ -0,0 +1,183 @@
+from collections import defaultdict
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long
+
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
+OptState = ipex.cpu.autocast._grad_scaler.OptState
+_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
+_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
+
+def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
+    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
+    per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
+    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
+    # There could be hundreds of grads, so we'd like to iterate through them just once.
+    # However, we don't know their devices or dtypes in advance.
+
+    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
+    # Google says mypy struggles with defaultdicts type annotations.
+    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
+    # sync grad to master weight
+    if hasattr(optimizer, "sync_grad"):
+        optimizer.sync_grad()
+    with torch.no_grad():
+        for group in optimizer.param_groups:
+            for param in group["params"]:
+                if param.grad is None:
+                    continue
+                if (not allow_fp16) and param.grad.dtype == torch.float16:
+                    raise ValueError("Attempting to unscale FP16 gradients.")
+                if param.grad.is_sparse:
+                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                    # coalesce() deduplicates indices and adds all values that have the same index.
+                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                    # so we should check the coalesced _values().
+                    if param.grad.dtype is torch.float16:
+                        param.grad = param.grad.coalesce()
+                    to_unscale = param.grad._values()
+                else:
+                    to_unscale = param.grad
+
+                # TODO: is there a way to split by device and dtype without appending in the inner loop?
+                to_unscale = to_unscale.to("cpu")
+                per_device_and_dtype_grads[to_unscale.device][
+                    to_unscale.dtype
+                ].append(to_unscale)
+
+        for _, per_dtype_grads in per_device_and_dtype_grads.items():
+            for grads in per_dtype_grads.values():
+                core._amp_foreach_non_finite_check_and_unscale_(
+                    grads,
+                    per_device_found_inf.get("cpu"),
+                    per_device_inv_scale.get("cpu"),
+                )
+
+    return per_device_found_inf._per_device_tensors
+
+def unscale_(self, optimizer):
+    """
+    Divides ("unscales") the optimizer's gradient tensors by the scale factor.
+    :meth:`unscale_` is optional, serving cases where you need to
+    :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
+    between the backward pass(es) and :meth:`step`.
+    If :meth:`unscale_` is not called explicitly,  gradients will be unscaled  automatically during :meth:`step`.
+    Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
+        ...
+        scaler.scale(loss).backward()
+        scaler.unscale_(optimizer)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+        scaler.step(optimizer)
+        scaler.update()
+    Args:
+        optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.
+    .. warning::
+        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
+        and only after all gradients for that optimizer's assigned parameters have been accumulated.
+        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
+    .. warning::
+        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
+    """
+    if not self._enabled:
+        return
+
+    self._check_scale_growth_tracker("unscale_")
+
+    optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+    if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
+        raise RuntimeError(
+            "unscale_() has already been called on this optimizer since the last update()."
+        )
+    elif optimizer_state["stage"] is OptState.STEPPED:
+        raise RuntimeError("unscale_() is being called after step().")
+
+    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+    assert self._scale is not None
+    if device_supports_fp64:
+        inv_scale = self._scale.double().reciprocal().float()
+    else:
+        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
+    found_inf = torch.full(
+        (1,), 0.0, dtype=torch.float32, device=self._scale.device
+    )
+
+    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+        optimizer, inv_scale, found_inf, False
+    )
+    optimizer_state["stage"] = OptState.UNSCALED
+
+def update(self, new_scale=None):
+    """
+    Updates the scale factor.
+    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
+    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
+    the scale is multiplied by ``growth_factor`` to increase it.
+    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
+    used directly, it's used to fill GradScaler's internal scale tensor. So if
+    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
+    affect the scale GradScaler uses internally.)
+    Args:
+        new_scale (float or :class:`torch.FloatTensor`, optional, default=None):  New scale factor.
+    .. warning::
+        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
+        been invoked for all optimizers used this iteration.
+    """
+    if not self._enabled:
+        return
+
+    _scale, _growth_tracker = self._check_scale_growth_tracker("update")
+
+    if new_scale is not None:
+        # Accept a new user-defined scale.
+        if isinstance(new_scale, float):
+            self._scale.fill_(new_scale)  # type: ignore[union-attr]
+        else:
+            reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
+            assert isinstance(new_scale, torch.FloatTensor), reason  # type: ignore[attr-defined]
+            assert new_scale.numel() == 1, reason
+            assert new_scale.requires_grad is False, reason
+            self._scale.copy_(new_scale)  # type: ignore[union-attr]
+    else:
+        # Consume shared inf/nan data collected from optimizers to update the scale.
+        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
+        found_infs = [
+            found_inf.to(device="cpu", non_blocking=True)
+            for state in self._per_optimizer_states.values()
+            for found_inf in state["found_inf_per_device"].values()
+        ]
+
+        assert len(found_infs) > 0, "No inf checks were recorded prior to update."
+
+        found_inf_combined = found_infs[0]
+        if len(found_infs) > 1:
+            for i in range(1, len(found_infs)):
+                found_inf_combined += found_infs[i]
+
+        to_device = _scale.device
+        _scale = _scale.to("cpu")
+        _growth_tracker = _growth_tracker.to("cpu")
+
+        core._amp_update_scale_(
+            _scale,
+            _growth_tracker,
+            found_inf_combined,
+            self._growth_factor,
+            self._backoff_factor,
+            self._growth_interval,
+        )
+
+        _scale = _scale.to(to_device)
+        _growth_tracker = _growth_tracker.to(to_device)
+    # To prepare for next iteration, clear the data collected from optimizers this iteration.
+    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+
+def gradscaler_init():
+    torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
+    torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
+    torch.xpu.amp.GradScaler.unscale_ = unscale_
+    torch.xpu.amp.GradScaler.update = update
+    return torch.xpu.amp.GradScaler
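+
+# Hedged usage sketch (model, optimizer and loss are hypothetical; ipex_init() falls back to
+# gradscaler_init() when torch.xpu.amp.GradScaler cannot be used directly):
+#
+#   scaler = gradscaler_init()()        # instance of the patched GradScaler
+#   scaler.scale(loss).backward()
+#   scaler.unscale_(optimizer)          # optional, e.g. before gradient clipping
+#   scaler.step(optimizer)
+#   scaler.update()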
diff --git a/external/llite/library/ipex/hijacks.py b/external/llite/library/ipex/hijacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb5f779f94b871ae93138e7062c01e35040bb652
--- /dev/null
+++ b/external/llite/library/ipex/hijacks.py
@@ -0,0 +1,252 @@
+import contextlib
+import importlib
+import torch
+import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+
+# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
+
+class CondFunc: # pylint: disable=missing-class-docstring
+    def __new__(cls, orig_func, sub_func, cond_func):
+        self = super(CondFunc, cls).__new__(cls)
+        if isinstance(orig_func, str):
+            func_path = orig_func.split('.')
+            for i in range(len(func_path)-1, -1, -1):
+                try:
+                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+                    break
+                except ImportError:
+                    pass
+            for attr_name in func_path[i:-1]:
+                resolved_obj = getattr(resolved_obj, attr_name)
+            orig_func = getattr(resolved_obj, func_path[-1])
+            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+        self.__init__(orig_func, sub_func, cond_func)
+        return lambda *args, **kwargs: self(*args, **kwargs)
+    def __init__(self, orig_func, sub_func, cond_func):
+        self.__orig_func = orig_func
+        self.__sub_func = sub_func
+        self.__cond_func = cond_func
+    def __call__(self, *args, **kwargs):
+        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+            return self.__sub_func(self.__orig_func, *args, **kwargs)
+        else:
+            return self.__orig_func(*args, **kwargs)
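+
+# CondFunc('pkg.mod.fn', sub_func, cond_func) resolves the dotted path, swaps the attribute for
+# a wrapper, and at call time runs sub_func(orig, *args, **kwargs) whenever cond_func is falsy
+# or returns True; otherwise it falls through to the original. A hedged, hypothetical example in
+# the same style as ipex_hijacks() below (return_xpu/check_device are the helpers defined later
+# in this file):
+#
+#   CondFunc('torch.arange',
+#       lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+#       lambda orig_func, *args, device=None, **kwargs: check_device(device))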
+
+_utils = torch.utils.data._utils
+def _shutdown_workers(self):
+    if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
+        return
+    if hasattr(self, "_shutdown") and not self._shutdown:
+        self._shutdown = True
+        try:
+            if hasattr(self, '_pin_memory_thread'):
+                self._pin_memory_thread_done_event.set()
+                self._worker_result_queue.put((None, None))
+                self._pin_memory_thread.join()
+                self._worker_result_queue.cancel_join_thread()
+                self._worker_result_queue.close()
+            self._workers_done_event.set()
+            for worker_id in range(len(self._workers)):
+                if self._persistent_workers or self._workers_status[worker_id]:
+                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
+            for w in self._workers: # pylint: disable=invalid-name
+                w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
+            for q in self._index_queues: # pylint: disable=invalid-name
+                q.cancel_join_thread()
+                q.close()
+        finally:
+            if self._worker_pids_set:
+                torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
+                self._worker_pids_set = False
+            for w in self._workers: # pylint: disable=invalid-name
+                if w.is_alive():
+                    w.terminate()
+
+class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
+    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
+        if isinstance(device_ids, list) and len(device_ids) > 1:
+            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
+        return module.to("xpu")
+
+def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
+    return contextlib.nullcontext()
+
+def check_device(device):
+    return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
+
+def return_xpu(device):
+    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
+
+def ipex_no_cuda(orig_func, *args, **kwargs):
+    torch.cuda.is_available = lambda: False
+    orig_func(*args, **kwargs)
+    torch.cuda.is_available = torch.xpu.is_available
+
+original_autocast = torch.autocast
+def ipex_autocast(*args, **kwargs):
+    if len(args) > 0 and args[0] == "cuda":
+        return original_autocast("xpu", *args[1:], **kwargs)
+    else:
+        return original_autocast(*args, **kwargs)
+
+# Embedding BF16
+original_torch_cat = torch.cat
+def torch_cat(tensor, *args, **kwargs):
+    if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
+        return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
+    else:
+        return original_torch_cat(tensor, *args, **kwargs)
+
+# Latent antialias:
+original_interpolate = torch.nn.functional.interpolate
+def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
+    if antialias or align_corners is not None:
+        return_device = tensor.device
+        return_dtype = tensor.dtype
+        return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
+        align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
+    else:
+        return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
+        align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
+
+original_linalg_solve = torch.linalg.solve
+def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
+    if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
+        return_device = A.device
+        return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
+    else:
+        return original_linalg_solve(A, B, *args, **kwargs)
+
+if torch.xpu.has_fp64_dtype():
+    original_torch_bmm = torch.bmm
+    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+else:
+    # 64 bit attention workarounds for Alchemist:
+    try:
+        from .attention import torch_bmm_32_bit as original_torch_bmm
+        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
+    except Exception: # pylint: disable=broad-exception-caught
+        original_torch_bmm = torch.bmm
+        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+
+# dtype errors:
+def torch_bmm(input, mat2, *, out=None):
+    if input.dtype != mat2.dtype:
+        mat2 = mat2.to(input.dtype)
+    return original_torch_bmm(input, mat2, out=out)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    if query.dtype != key.dtype:
+        key = key.to(dtype=query.dtype)
+    if query.dtype != value.dtype:
+        value = value.to(dtype=query.dtype)
+    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+
+@property
+def is_cuda(self):
+    return self.device.type == 'xpu'
+
+def ipex_hijacks():
+    CondFunc('torch.tensor',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.Tensor.to',
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+    CondFunc('torch.Tensor.cuda',
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+    CondFunc('torch.UntypedStorage.__init__',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.UntypedStorage.cuda',
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+    CondFunc('torch.empty',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.randn',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.ones',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.zeros',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.linspace',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.load',
+        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
+        orig_func(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
+        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
+    if hasattr(torch.xpu, "Generator"):
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
+            lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
+    else:
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: orig_func(return_xpu(device)),
+            lambda orig_func, device=None: check_device(device))
+
+    # TiledVAE and ControlNet:
+    CondFunc('torch.batch_norm',
+        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
+        weight if weight is not None else torch.ones(input.size()[1], device=input.device),
+        bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
+        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
+    CondFunc('torch.instance_norm',
+        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
+        weight if weight is not None else torch.ones(input.size()[1], device=input.device),
+        bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
+        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
+
+    # Functions with dtype errors:
+    CondFunc('torch.nn.modules.GroupNorm.forward',
+        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    # Training:
+    CondFunc('torch.nn.modules.linear.Linear.forward',
+        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    CondFunc('torch.nn.modules.conv.Conv2d.forward',
+        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    # BF16:
+    CondFunc('torch.nn.functional.layer_norm',
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+        weight is not None and input.dtype != weight.data.dtype)
+    # SwinIR BF16:
+    CondFunc('torch.nn.functional.pad',
+        lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
+        lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
+
+    # Diffusers Float64 (Alchemist GPUs don't support 64 bit):
+    if not torch.xpu.has_fp64_dtype():
+        CondFunc('torch.from_numpy',
+        lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
+        lambda orig_func, ndarray: ndarray.dtype == float)
+
+    # Broken functions when torch.cuda.is_available is True:
+    # Pin Memory:
+    CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
+        lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
+        lambda orig_func, *args, **kwargs: True)
+
+    # Functions that make compile mad with CondFunc:
+    torch.nn.DataParallel = DummyDataParallel
+    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
+
+    torch.autocast = ipex_autocast
+    torch.backends.cuda.sdp_kernel = return_null_context
+    torch.UntypedStorage.is_cuda = is_cuda
+
+    torch.nn.functional.interpolate = interpolate
+    torch.linalg.solve = linalg_solve
+
+    torch.bmm = torch_bmm
+    torch.cat = torch_cat
+    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
diff --git a/external/llite/library/lpw_stable_diffusion.py b/external/llite/library/lpw_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dce91a76ea40395f884d1da53855a8834b7f150
--- /dev/null
+++ b/external/llite/library/lpw_stable_diffusion.py
@@ -0,0 +1,1254 @@
+# copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
+# and modified to support SD2.x
+
+import inspect
+import re
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import SchedulerMixin, StableDiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.utils import logging
+
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+re_attention = re.compile(
+    r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+    re.X,
+)
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+    Accepted tokens are:
+      (abc) - increases attention to abc by a multiplier of 1.1
+      (abc:3.12) - increases attention to abc by a multiplier of 3.12
+      [abc] - decreases attention to abc by a multiplier of 1.1
+      \( - literal character '('
+      \[ - literal character '['
+      \) - literal character ')'
+      \] - literal character ']'
+      \\ - literal character '\'
+      anything else - just text
+    >>> parse_prompt_attention('normal text')
+    [['normal text', 1.0]]
+    >>> parse_prompt_attention('an (important) word')
+    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+    >>> parse_prompt_attention('(unbalanced')
+    [['unbalanced', 1.1]]
+    >>> parse_prompt_attention('\(literal\]')
+    [['(literal]', 1.0]]
+    >>> parse_prompt_attention('(unnecessary)(parens)')
+    [['unnecessaryparens', 1.1]]
+    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+    [['a ', 1.0],
+     ['house', 1.5730000000000004],
+     [' ', 1.1],
+     ['on', 1.0],
+     [' a ', 1.1],
+     ['hill', 0.55],
+     [', sun, ', 1.1],
+     ['sky', 1.4641000000000006],
+     ['.', 1.1]]
+    """
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith("\\"):
+            res.append([text[1:], 1.0])
+        elif text == "(":
+            round_brackets.append(len(res))
+        elif text == "[":
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ")" and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == "]" and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            res.append([text, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    if len(res) == 0:
+        res = [["", 1.0]]
+
+    # merge runs of identical weights
+    i = 0
+    while i + 1 < len(res):
+        if res[i][1] == res[i + 1][1]:
+            res[i][0] += res[i + 1][0]
+            res.pop(i + 1)
+        else:
+            i += 1
+
+    return res
+
+
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
+    r"""
+    Tokenize a list of prompts and return its tokens with weights of each token.
+
+    No padding, starting or ending token is included.
+    """
+    tokens = []
+    weights = []
+    truncated = False
+    for text in prompt:
+        texts_and_weights = parse_prompt_attention(text)
+        text_token = []
+        text_weight = []
+        for word, weight in texts_and_weights:
+            # tokenize and discard the starting and the ending token
+            token = pipe.tokenizer(word).input_ids[1:-1]
+            text_token += token
+            # copy the weight by length of token
+            text_weight += [weight] * len(token)
+            # stop if the text is too long (longer than truncation limit)
+            if len(text_token) > max_length:
+                truncated = True
+                break
+        # truncate
+        if len(text_token) > max_length:
+            truncated = True
+            text_token = text_token[:max_length]
+            text_weight = text_weight[:max_length]
+        tokens.append(text_token)
+        weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+    return tokens, weights
+
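+# Illustrative note (not executed): for a weighted prompt such as "a (red:1.2) car",
+# get_prompts_with_weights returns one token-id list per prompt plus a parallel per-token
+# weight list, e.g. [1.0, 1.2, 1.0, ...] with the 1.2 repeated for every sub-word token of
+# "red"; the exact ids and lengths depend on the tokenizer of the pipeline that is passed in.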
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+    r"""
+    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+    """
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+    for i in range(len(tokens)):
+        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        if no_boseos_middle:
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+        else:
+            w = []
+            if len(weights[i]) == 0:
+                w = [1.0] * weights_length
+            else:
+                for j in range(max_embeddings_multiples):
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
+            weights[i] = w[:]
+
+    return tokens, weights
+
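+# Rough worked example (assuming CLIP's chunk_length of 77, as used by the caller below): for a
+# prompt of 100 content tokens, get_weighted_text_embeddings computes max_length = 2 * 75 + 2 = 152,
+# and with no_boseos_middle=False the weights are laid out per 77-token chunk as
+# [1.0 (BOS), up to 75 content weights, 1.0 (EOS)], so they line up with the chunked embeddings.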
+
+def get_unweighted_text_embeddings(
+    pipe: StableDiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    clip_skip: int,
+    eos: int,
+    pad: int,
+    no_boseos_middle: Optional[bool] = True,
+):
+    """
+    When the token sequence is longer than the capacity of the text encoder, it is split into
+    chunks and each chunk is sent to the text encoder individually.
+    """
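+    # e.g. with chunk_length=77, a padded input of shape (B, 152) gives
+    # max_embeddings_multiples = (152 - 2) // (77 - 2) = 2, so the input is encoded as two
+    # 77-token chunks whose embeddings are concatenated along the sequence axis.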
+    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+    if max_embeddings_multiples > 1:
+        text_embeddings = []
+        for i in range(max_embeddings_multiples):
+            # extract the i-th chunk
+            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+            # cover the head and the tail by the starting and the ending tokens
+            text_input_chunk[:, 0] = text_input[0, 0]
+            if pad == eos:  # v1
+                text_input_chunk[:, -1] = text_input[0, -1]
+            else:  # v2
+                for j in range(len(text_input_chunk)):
+                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with a normal token, not EOS/PAD
+                        text_input_chunk[j, -1] = eos
+                    if text_input_chunk[j, 1] == pad:  # the chunk contains only BOS followed by PAD
+                        text_input_chunk[j, 1] = eos
+
+            if clip_skip is None or clip_skip == 1:
+                text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            else:
+                enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
+                text_embedding = enc_out["hidden_states"][-clip_skip]
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
+
+            if no_boseos_middle:
+                if i == 0:
+                    # discard the ending token
+                    text_embedding = text_embedding[:, :-1]
+                elif i == max_embeddings_multiples - 1:
+                    # discard the starting token
+                    text_embedding = text_embedding[:, 1:]
+                else:
+                    # discard both starting and ending tokens
+                    text_embedding = text_embedding[:, 1:-1]
+
+            text_embeddings.append(text_embedding)
+        text_embeddings = torch.concat(text_embeddings, axis=1)
+    else:
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = pipe.text_encoder(text_input)[0]
+        else:
+            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
+    return text_embeddings
+
+
+def get_weighted_text_embeddings(
+    pipe: StableDiffusionPipeline,
+    prompt: Union[str, List[str]],
+    uncond_prompt: Optional[Union[str, List[str]]] = None,
+    max_embeddings_multiples: Optional[int] = 3,
+    no_boseos_middle: Optional[bool] = False,
+    skip_parsing: Optional[bool] = False,
+    skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+):
+    r"""
+    Prompts can be assigned with local weights using brackets. For example,
+    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+    Args:
+        pipe (`StableDiffusionPipeline`):
+            Pipe to provide access to the tokenizer and the text encoder.
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to guide the image generation.
+        uncond_prompt (`str` or `List[str]`):
+            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+            is provided, the embeddings of prompt and uncond_prompt are concatenated.
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+            The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
+            starting and ending tokens of each chunk in the middle.
+        skip_parsing (`bool`, *optional*, defaults to `False`):
+            Skip the parsing of brackets.
+        skip_weighting (`bool`, *optional*, defaults to `False`):
+            Skip the weighting. When the parsing is skipped, it is forced True.
+    """
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    if not skip_parsing:
+        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+    else:
+        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
+        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens = [
+                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+            ]
+            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+    # round up the longest length of tokens to a multiple of (model_max_length - 2)
+    max_length = max([len(token) for token in prompt_tokens])
+    if uncond_prompt is not None:
+        max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+    max_embeddings_multiples = min(
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+    )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+    # pad the length of tokens and weights
+    bos = pipe.tokenizer.bos_token_id
+    eos = pipe.tokenizer.eos_token_id
+    pad = pipe.tokenizer.pad_token_id
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=pipe.tokenizer.model_max_length,
+    )
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+    if uncond_prompt is not None:
+        uncond_tokens, uncond_weights = pad_tokens_and_weights(
+            uncond_tokens,
+            uncond_weights,
+            max_length,
+            bos,
+            eos,
+            no_boseos_middle=no_boseos_middle,
+            chunk_length=pipe.tokenizer.model_max_length,
+        )
+        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+    # get the embeddings
+    text_embeddings = get_unweighted_text_embeddings(
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        clip_skip,
+        eos,
+        pad,
+        no_boseos_middle=no_boseos_middle,
+    )
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
+    if uncond_prompt is not None:
+        uncond_embeddings = get_unweighted_text_embeddings(
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            clip_skip,
+            eos,
+            pad,
+            no_boseos_middle=no_boseos_middle,
+        )
+        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
+
+    # assign weights to the prompts and normalize in the sense of mean
+    # TODO: should we normalize per chunk or as a whole (current implementation)?
+    if (not skip_parsing) and (not skip_weighting):
+        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= prompt_weights.unsqueeze(-1)
+        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+        if uncond_prompt is not None:
+            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= uncond_weights.unsqueeze(-1)
+            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+    if uncond_prompt is not None:
+        return text_embeddings, uncond_embeddings
+    return text_embeddings, None
+
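+# Minimal usage sketch (commented out; it assumes an already constructed StableDiffusionPipeline
+# instance named `pipe` and is only meant to illustrate the call, not prescribe it):
+#
+#   cond, uncond = get_weighted_text_embeddings(
+#       pipe,
+#       prompt="a (very beautiful:1.3) landscape",
+#       uncond_prompt="lowres, blurry",
+#       max_embeddings_multiples=3,
+#   )
+#   # cond / uncond have shape roughly (1, n_chunks * 77, hidden_dim) and can be fed to the UNet
+#   # as encoder_hidden_states, as _encode_prompt in the pipeline class below does.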
+
+def preprocess_image(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+    mask = mask.convert("L")
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # no-op permutation; the [None] above already added the batch dimension
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_controlnet_image(
+    image: PIL.Image.Image,
+    width: int,
+    height: int,
+    batch_size: int,
+    num_images_per_prompt: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    do_classifier_free_guidance: bool = False,
+    guess_mode: bool = False,
+):
+    if not isinstance(image, torch.Tensor):
+        if isinstance(image, PIL.Image.Image):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            images = []
+
+            for image_ in image:
+                image_ = image_.convert("RGB")
+                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+                image_ = np.array(image_)
+                image_ = image_[None, :]
+                images.append(image_)
+
+            image = images
+
+            image = np.concatenate(image, axis=0)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = image.transpose(0, 3, 1, 2)
+            image = torch.from_numpy(image)
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0)
+
+    image_batch_size = image.shape[0]
+
+    if image_batch_size == 1:
+        repeat_by = batch_size
+    else:
+        # image batch size is the same as prompt batch size
+        repeat_by = num_images_per_prompt
+
+    image = image.repeat_interleave(repeat_by, dim=0)
+
+    image = image.to(device=device, dtype=dtype)
+
+    if do_classifier_free_guidance and not guess_mode:
+        image = torch.cat([image] * 2)
+
+    return image
+
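+# Note: prepare_controlnet_image returns a float tensor scaled to [0, 1] with shape roughly
+# (batch_size * num_images_per_prompt [* 2 under classifier-free guidance], 3, height, width),
+# ready to be passed as controlnet_cond inside the denoising loop below.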
+
+class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+    parsing weights in the prompt.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: SchedulerMixin,
+        # clip_skip: int,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
+        clip_skip: int = 1,
+    ):
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
+        )
+        self.clip_skip = clip_skip
+        self.__init__additional__()
+
+    # else:
+    #     def __init__(
+    #         self,
+    #         vae: AutoencoderKL,
+    #         text_encoder: CLIPTextModel,
+    #         tokenizer: CLIPTokenizer,
+    #         unet: UNet2DConditionModel,
+    #         scheduler: SchedulerMixin,
+    #         safety_checker: StableDiffusionSafetyChecker,
+    #         feature_extractor: CLIPFeatureExtractor,
+    #     ):
+    #         super().__init__(
+    #             vae=vae,
+    #             text_encoder=text_encoder,
+    #             tokenizer=tokenizer,
+    #             unet=unet,
+    #             scheduler=scheduler,
+    #             safety_checker=safety_checker,
+    #             feature_extractor=feature_extractor,
+    #         )
+    #         self.__init__additional__()
+
+    def __init__additional__(self):
+        if not hasattr(self, "vae_scale_factor"):
+            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+            clip_skip=self.clip_skip,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            print(height, width)
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps.to(device), num_inference_steps
+        else:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(device)
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
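+        # 0.18215 is the fixed latent scaling factor of the SD v1/v2 VAE; dividing by it here
+        # undoes the scaling applied when the latents were produced.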
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet.in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+                else:
+                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+                latents = latents.to(device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+            return latents, None, None
+        else:
+            init_latent_dist = self.vae.encode(image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape
+
+            # add noise to latents using the timesteps
+            if device.type == "mps":
+                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            return latents, init_latents_orig, noise
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        strength: float = 0.8,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        controlnet=None,
+        controlnet_image=None,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            controlnet (`diffusers.ControlNetModel`, *optional*):
+                A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
+            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+                `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
+                inference.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            `None` if cancelled by `is_cancelled_callback`,
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        if controlnet is not None and controlnet_image is None:
+            raise ValueError("controlnet_image must be provided if controlnet is not None.")
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            max_embeddings_multiples,
+        )
+        dtype = text_embeddings.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.to(device=self.device, dtype=dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.to(device=self.device, dtype=dtype)
+            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        else:
+            mask = None
+
+        if controlnet_image is not None:
+            controlnet_image = prepare_controlnet_image(
+                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
+            )
+
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            unet_additional_args = {}
+            if controlnet is not None:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    controlnet_cond=controlnet_image,
+                    conditioning_scale=1.0,
+                    guess_mode=False,
+                    return_dict=False,
+                )
+                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
+                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            if mask is not None:
+                # masking
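+                # re-noise the original latents to the current timestep, then keep them where the
+                # mask is 1 (preserved region) and keep the freshly denoised latents where it is 0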
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None
+
+        return latents
+
+    def latents_to_image(self, latents):
+        # 9. Post-processing
+        image = self.decode_latents(latents.to(self.vae.dtype))
+        image = self.numpy_to_pil(image)
+        return image
+
+    def text2img(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for text-to-image generation.
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            latents=latents,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+    def img2img(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for image-to-image generation.
+        Args:
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter will be modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+    def inpaint(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image],
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for inpaint.
+        Args:
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process. This is the image whose masked region will be inpainted.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+                is 1, the denoising process will be run on the masked area for the full number of iterations specified
+                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image,
+            mask_image=mask_image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
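+
+    # A minimal usage sketch for the wrapper above (illustrative only; `pipe`, the image variables,
+    # and the method name are assumptions; the wrapper simply forwards to __call__):
+    #
+    #   output = pipe.inpaint(
+    #       prompt="a cat sitting on a bench",
+    #       image=init_image,          # reference picture for the masked area (PIL.Image)
+    #       mask_image=mask_image,     # mask; conventionally, white pixels are repainted
+    #       strength=0.75,             # roughly 75% of num_inference_steps are actually run
+    #       num_inference_steps=50,
+    #       guidance_scale=7.5,
+    #   )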
diff --git a/external/llite/library/model_util.py b/external/llite/library/model_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9b408f7117fc78293d178324704782b7423541
--- /dev/null
+++ b/external/llite/library/model_util.py
@@ -0,0 +1,1350 @@
+# v1: split from train_db_fixed.py.
+# v2: support safetensors
+
+import math
+import os
+import torch
+try:
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        from library.ipex import ipex_init
+        ipex_init()
+except Exception:
+    pass
+import diffusers
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+from external.llite.library.original_unet import UNet2DConditionModel
+
+# Model parameters for the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+# UNET_PARAMS_USE_LINEAR_PROJECTION = False
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
+
+# Reference models used to load the Diffusers configs
+DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
+DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
+
+
+# region Conversion code: Stable Diffusion -> Diffusers
+# Copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+    """
+    Removes segments. Positive values shave the first segments, negative shave the last segments.
+    """
+    if n_shave_prefix_segments >= 0:
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
+    else:
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
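+
+# Illustrative examples (not used by the conversion code itself):
+#   shave_segments("input_blocks.3.0.op.weight", 2)  -> "0.op.weight"
+#   shave_segments("input_blocks.3.0.op.weight", -1) -> "input_blocks.3.0.op"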
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside resnets to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item.replace("in_layers.0", "norm1")
+        new_item = new_item.replace("in_layers.2", "conv1")
+
+        new_item = new_item.replace("out_layers.0", "norm2")
+        new_item = new_item.replace("out_layers.3", "conv2")
+
+        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+        new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
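+
+# For example, with the default n_shave_prefix_segments=0 this yields mappings such as
+#   {"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"};
+# the block prefix itself is renamed later by assign_to_checkpoint via additional_replacements.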
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside resnets to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')
+        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+    """
+    Updates paths inside attentions to the new naming scheme (local renaming)
+    """
+    mapping = []
+    for old_item in old_list:
+        new_item = old_item
+
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+        if diffusers.__version__ < "0.17.0":
+            new_item = new_item.replace("q.weight", "query.weight")
+            new_item = new_item.replace("q.bias", "query.bias")
+
+            new_item = new_item.replace("k.weight", "key.weight")
+            new_item = new_item.replace("k.bias", "key.bias")
+
+            new_item = new_item.replace("v.weight", "value.weight")
+            new_item = new_item.replace("v.bias", "value.bias")
+
+            new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+            new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        else:
+            new_item = new_item.replace("q.weight", "to_q.weight")
+            new_item = new_item.replace("q.bias", "to_q.bias")
+
+            new_item = new_item.replace("k.weight", "to_k.weight")
+            new_item = new_item.replace("k.bias", "to_k.bias")
+
+            new_item = new_item.replace("v.weight", "to_v.weight")
+            new_item = new_item.replace("v.bias", "to_v.bias")
+
+            new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+            new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+        mapping.append({"old": old_item, "new": new_item})
+
+    return mapping
+
+
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+    """
+    This does the final conversion step: take locally converted weights and apply a global renaming
+    to them. It splits attention layers, and takes into account additional replacements
+    that may arise.
+
+    Assigns the weights to the new checkpoint.
+    """
+    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+    # Splits the attention layers into three variables.
+    if attention_paths_to_split is not None:
+        for path, path_map in attention_paths_to_split.items():
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
+
+            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+    for path in paths:
+        new_path = path["new"]
+
+        # These have already been assigned
+        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+            continue
+
+        # Global renaming happens here
+        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+        if additional_replacements is not None:
+            for replacement in additional_replacements:
+                new_path = new_path.replace(replacement["old"], replacement["new"])
+
+        # proj_attn.weight has to be converted from conv 1D to linear
+        reshaping = False
+        if diffusers.__version__ < "0.17.0":
+            if "proj_attn.weight" in new_path:
+                reshaping = True
+        else:
+            if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
+                reshaping = True
+
+        if reshaping:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+        else:
+            checkpoint[new_path] = old_checkpoint[path["old"]]
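+
+# Worked example of the renaming above: a path mapping produced by renew_resnet_paths(),
+# combined with the additional replacement {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"},
+# turns the SD key "input_blocks.1.0.in_layers.0.weight" into the Diffusers key
+# "down_blocks.0.resnets.0.norm1.weight" before its tensor is copied into the new checkpoint.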
+
+
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def linear_transformer_to_conv(checkpoint):
+    keys = list(checkpoint.keys())
+    tf_keys = ["proj_in.weight", "proj_out.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in tf_keys:
+            if checkpoint[key].ndim == 2:
+                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+
+    # extract state_dict for UNet
+    unet_state_dict = {}
+    unet_key = "model.diffusion_model."
+    keys = list(checkpoint.keys())
+    for key in keys:
+        if key.startswith(unet_key):
+            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
+    }
+
+    # Retrieves the keys for the middle blocks only
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
+    }
+
+    # Retrieves the keys for the output blocks only
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
+    }
+
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+
+        paths = renew_resnet_paths(resnets)
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+        if len(attentions):
+            paths = renew_attention_paths(attentions)
+            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+    resnet_0 = middle_blocks[0]
+    attentions = middle_blocks[1]
+    resnet_1 = middle_blocks[2]
+
+    resnet_0_paths = renew_resnet_paths(resnet_0)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+    resnet_1_paths = renew_resnet_paths(resnet_1)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+    attentions_paths = renew_attention_paths(attentions)
+    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+    for i in range(num_output_blocks):
+        block_id = i // (config["layers_per_block"] + 1)
+        layer_in_block_id = i % (config["layers_per_block"] + 1)
+        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+        output_block_list = {}
+
+        for layer in output_block_layers:
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+            if layer_id in output_block_list:
+                output_block_list[layer_id].append(layer_name)
+            else:
+                output_block_list[layer_id] = [layer_name]
+
+        if len(output_block_list) > 1:
+            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+            resnet_0_paths = renew_resnet_paths(resnets)
+            paths = renew_resnet_paths(resnets)
+
+            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+            # Original:
+            # if ["conv.weight", "conv.bias"] in output_block_list.values():
+            #   index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+            # make this independent of the order of bias and weight (there is probably a better way):
+            for l in output_block_list.values():
+                l.sort()
+
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+
+                # Clear attentions as they have been attributed above.
+                if len(attentions) == 2:
+                    attentions = []
+
+            if len(attentions):
+                paths = renew_attention_paths(attentions)
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
+                assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+        else:
+            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+    # In SD v2, the 1x1 conv2d has been replaced by a linear layer.
+    # The Diffusers side was mistakenly left as conv2d, so conversion is required.
+    if v2 and not config.get("use_linear_projection", False):
+        linear_transformer_to_conv(new_checkpoint)
+
+    return new_checkpoint
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+    # extract state dict for VAE
+    vae_state_dict = {}
+    vae_key = "first_stage_model."
+    keys = list(checkpoint.keys())
+    for key in keys:
+        if key.startswith(vae_key):
+            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # if len(vae_state_dict) == 0:
+    #   # the given checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt file
+    #   vae_state_dict = checkpoint
+
+    new_checkpoint = {}
+
+    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+    # Retrieves the keys for the encoder down blocks only
+    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+    down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
+
+    # Retrieves the keys for the decoder up blocks only
+    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+    up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+    for i in range(num_down_blocks):
+        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.weight"
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.bias"
+            )
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.bias"
+            ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+    return new_checkpoint
+
+
+def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    # unet_params = original_config.model.params.unet_config.params
+
+    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
+
+    down_block_types = []
+    resolution = 1
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
+        down_block_types.append(block_type)
+        if i != len(block_out_channels) - 1:
+            resolution *= 2
+
+    up_block_types = []
+    for i in range(len(block_out_channels)):
+        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
+        up_block_types.append(block_type)
+        resolution //= 2
+
+    config = dict(
+        sample_size=UNET_PARAMS_IMAGE_SIZE,
+        in_channels=UNET_PARAMS_IN_CHANNELS,
+        out_channels=UNET_PARAMS_OUT_CHANNELS,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+        attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+        # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
+    )
+    if v2 and use_linear_projection_in_v2:
+        config["use_linear_projection"] = True
+
+    return config
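+
+# With the constants above, this produces for both v1 and v2:
+#   down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
+#   up_block_types   = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+# which matches the reference configs quoted in original_unet.py.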
+
+
+def create_vae_diffusers_config():
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
+    # _ = original_config.model.params.first_stage_config.params.embed_dim
+    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = dict(
+        sample_size=VAE_PARAMS_RESOLUTION,
+        in_channels=VAE_PARAMS_IN_CHANNELS,
+        out_channels=VAE_PARAMS_OUT_CH,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        latent_channels=VAE_PARAMS_Z_CHANNELS,
+        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
+    )
+    return config
+
+
+def convert_ldm_clip_checkpoint_v1(checkpoint):
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+    # support checkpoint without position_ids (invalid checkpoint)
+    if "text_model.embeddings.position_ids" not in text_model_dict:
+        text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
+
+    return text_model_dict
+
+
+def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+    # the key layouts are annoyingly different!
+    def convert_key(key):
+        if not key.startswith("cond_stage_model"):
+            return None
+
+        # common conversion
+        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+        key = key.replace("cond_stage_model.model.", "text_model.")
+
+        if "resblocks" in key:
+            # resblocks conversion
+            key = key.replace(".resblocks.", ".layers.")
+            if ".ln_" in key:
+                key = key.replace(".ln_", ".layer_norm")
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif ".attn.out_proj" in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif ".attn.in_proj" in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif ".positional_embedding" in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif ".text_projection" in key:
+            key = None  # apparently not used?
+        elif ".logit_scale" in key:
+            key = None  # apparently not used?
+        elif ".token_embedding" in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif ".ln_final" in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        # remove resblocks 23
+        if ".resblocks.23." in key:
+            continue
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert attention weights
+    for key in keys:
+        if ".resblocks.23." in key:
+            continue
+        if ".resblocks" in key and ".attn.in_proj_" in key:
+            # split into q, k, v
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    # rename or add position_ids
+    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
+    if ANOTHER_POSITION_IDS_KEY in new_sd:
+        # waifu diffusion v1.4
+        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
+        del new_sd[ANOTHER_POSITION_IDS_KEY]
+    else:
+        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+
+    new_sd["text_model.embeddings.position_ids"] = position_ids
+    return new_sd
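+
+# For reference, the in_proj handling above splits a fused key such as
+#   "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight"
+# into the three transformers-style keys
+#   "text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight"
+# by chunking the fused weight into equal thirds.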
+
+
+# endregion
+
+
+# region Conversion code: Diffusers -> Stable Diffusion
+# Copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
+
+
+def conv_transformer_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    tf_keys = ["proj_in.weight", "proj_out.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in tf_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+def convert_unet_state_dict_to_sd(v2, unet_state_dict):
+    unet_conversion_map = [
+        # (stable-diffusion, HF Diffusers)
+        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+        ("input_blocks.0.0.weight", "conv_in.weight"),
+        ("input_blocks.0.0.bias", "conv_in.bias"),
+        ("out.0.weight", "conv_norm_out.weight"),
+        ("out.0.bias", "conv_norm_out.bias"),
+        ("out.2.weight", "conv_out.weight"),
+        ("out.2.bias", "conv_out.bias"),
+    ]
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0", "norm1"),
+        ("in_layers.2", "conv1"),
+        ("out_layers.0", "norm2"),
+        ("out_layers.3", "conv2"),
+        ("emb_layers.1", "time_emb_proj"),
+        ("skip_connection", "conv_shortcut"),
+    ]
+
+    unet_conversion_map_layer = []
+    for i in range(4):
+        # loop over downblocks/upblocks
+
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            if i > 0:
+                # no attention layers in up_blocks.0
+                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
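+
+    # Examples of the accumulated prefix pairs (HF Diffusers name -> SD name):
+    #   "down_blocks.0.resnets.0."    -> "input_blocks.1.0."
+    #   "down_blocks.0.attentions.0." -> "input_blocks.1.1."
+    #   "mid_block.resnets.1."        -> "middle_block.2."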
+
+    # buyer beware: this is a *brittle* function,
+    # and correct output requires that all of these pieces interact in
+    # the exact order in which I have arranged them.
+    mapping = {k: k for k in unet_state_dict.keys()}
+    for sd_name, hf_name in unet_conversion_map:
+        mapping[hf_name] = sd_name
+    for k, v in mapping.items():
+        if "resnets" in k:
+            for sd_part, hf_part in unet_conversion_map_resnet:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+    for k, v in mapping.items():
+        for sd_part, hf_part in unet_conversion_map_layer:
+            v = v.replace(hf_part, sd_part)
+        mapping[k] = v
+    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+
+    if v2:
+        conv_transformer_to_linear(new_state_dict)
+
+    return new_state_dict
+
+
+def controlnet_conversion_map():
+    unet_conversion_map = [
+        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+        ("input_blocks.0.0.weight", "conv_in.weight"),
+        ("input_blocks.0.0.bias", "conv_in.bias"),
+        ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
+        ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
+    ]
+
+    unet_conversion_map_resnet = [
+        ("in_layers.0", "norm1"),
+        ("in_layers.2", "conv1"),
+        ("out_layers.0", "norm2"),
+        ("out_layers.3", "conv2"),
+        ("emb_layers.1", "time_emb_proj"),
+        ("skip_connection", "conv_shortcut"),
+    ]
+
+    unet_conversion_map_layer = []
+    for i in range(4):
+        for j in range(2):
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        if i < 3:
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
+    for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
+        hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
+        sd_prefix = f"input_hint_block.{i*2}."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+    for i in range(12):
+        hf_prefix = f"controlnet_down_blocks.{i}."
+        sd_prefix = f"zero_convs.{i}.0."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+    return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
+
+
+def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[diffusers_name] = sd_name
+    for k, v in mapping.items():
+        if "resnets" in k:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(diffusers_part, sd_part)
+            mapping[k] = v
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(diffusers_part, sd_part)
+        mapping[k] = v
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
+
+def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[sd_name] = diffusers_name
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(sd_part, diffusers_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "resnets" in v:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(sd_part, diffusers_part)
+            mapping[k] = v
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+
+def reshape_weight_for_sd(w):
+    # convert HF linear weights to SD conv2d weights
+    return w.reshape(*w.shape, 1, 1)
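+
+# e.g. a linear weight of shape (512, 512) becomes a 1x1 conv weight of shape (512, 512, 1, 1)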
+
+
+def convert_vae_state_dict(vae_state_dict):
+    vae_conversion_map = [
+        # (stable-diffusion, HF Diffusers)
+        ("nin_shortcut", "conv_shortcut"),
+        ("norm_out", "conv_norm_out"),
+        ("mid.attn_1.", "mid_block.attentions.0."),
+    ]
+
+    for i in range(4):
+        # down_blocks have two resnets
+        for j in range(2):
+            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+            sd_down_prefix = f"encoder.down.{i}.block.{j}."
+            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+        if i < 3:
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+            sd_downsample_prefix = f"down.{i}.downsample."
+            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"up.{3-i}.upsample."
+            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+        # up_blocks have three resnets
+        # also, up blocks in hf are numbered in reverse from sd
+        for j in range(3):
+            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+    # this part accounts for mid blocks in both the encoder and the decoder
+    for i in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{i}."
+        sd_mid_res_prefix = f"mid.block_{i+1}."
+        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    if diffusers.__version__ < "0.17.0":
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "query."),
+            ("k.", "key."),
+            ("v.", "value."),
+            ("proj_out.", "proj_attn."),
+        ]
+    else:
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "to_q."),
+            ("k.", "to_k."),
+            ("v.", "to_v."),
+            ("proj_out.", "to_out.0."),
+        ]
+
+    mapping = {k: k for k in vae_state_dict.keys()}
+    for k, v in mapping.items():
+        for sd_part, hf_part in vae_conversion_map:
+            v = v.replace(hf_part, sd_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "attentions" in k:
+            for sd_part, hf_part in vae_conversion_map_attn:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+    weights_to_convert = ["q", "k", "v", "proj_out"]
+    for k, v in new_state_dict.items():
+        for weight_name in weights_to_convert:
+            if f"mid.attn_1.{weight_name}.weight" in k:
+                # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
+                new_state_dict[k] = reshape_weight_for_sd(v)
+
+    return new_state_dict
+
+
+# endregion
+
+# region Custom model loading/saving utilities
+
+
+def is_safetensors(path):
+    return os.path.splitext(path)[1].lower() == ".safetensors"
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
+    # handle models whose text encoder is stored in a different layout (no 'text_model' prefix)
+    TEXT_ENCODER_KEY_REPLACEMENTS = [
+        ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
+        ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
+        ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
+    ]
+
+    if is_safetensors(ckpt_path):
+        checkpoint = None
+        state_dict = load_file(ckpt_path)  # not passing `device` here, as it may cause an error
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        else:
+            state_dict = checkpoint
+            checkpoint = None
+
+    key_reps = []
+    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+        for key in state_dict.keys():
+            if key.startswith(rep_from):
+                new_key = rep_to + key[len(rep_from) :]
+                key_reps.append((key, new_key))
+
+    for key, new_key in key_reps:
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]
+
+    return checkpoint, state_dict
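+
+# Example of the key renaming above:
+#   "cond_stage_model.transformer.encoder.layers.0.mlp.fc1.weight"
+#     -> "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight"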
+
+
+# TODO: the behavior of the dtype argument is questionable and should be verified;
+# it is unconfirmed whether the text_encoder can be created with the specified dtype
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
+    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
+
+    # Convert the UNet2DConditionModel model.
+    unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+    unet = UNet2DConditionModel(**unet_config).to(device)
+    info = unet.load_state_dict(converted_unet_checkpoint)
+    print("loading u-net:", info)
+
+    # Convert the VAE model.
+    vae_config = create_vae_diffusers_config()
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+    vae = AutoencoderKL(**vae_config).to(device)
+    info = vae.load_state_dict(converted_vae_checkpoint)
+    print("loading vae:", info)
+
+    # convert text_model
+    if v2:
+        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+        cfg = CLIPTextConfig(
+            vocab_size=49408,
+            hidden_size=1024,
+            intermediate_size=4096,
+            num_hidden_layers=23,
+            num_attention_heads=16,
+            max_position_embeddings=77,
+            hidden_act="gelu",
+            layer_norm_eps=1e-05,
+            dropout=0.0,
+            attention_dropout=0.0,
+            initializer_range=0.02,
+            initializer_factor=1.0,
+            pad_token_id=1,
+            bos_token_id=0,
+            eos_token_id=2,
+            model_type="clip_text_model",
+            projection_dim=512,
+            torch_dtype="float32",
+            transformers_version="4.25.0.dev0",
+        )
+        text_model = CLIPTextModel._from_config(cfg)
+        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+    else:
+        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+        # logging.set_verbosity_error()  # don't show annoying warning
+        # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+        # logging.set_verbosity_warning()
+        # print(f"config: {text_model.config}")
+        cfg = CLIPTextConfig(
+            vocab_size=49408,
+            hidden_size=768,
+            intermediate_size=3072,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            max_position_embeddings=77,
+            hidden_act="quick_gelu",
+            layer_norm_eps=1e-05,
+            dropout=0.0,
+            attention_dropout=0.0,
+            initializer_range=0.02,
+            initializer_factor=1.0,
+            pad_token_id=1,
+            bos_token_id=0,
+            eos_token_id=2,
+            model_type="clip_text_model",
+            projection_dim=768,
+            torch_dtype="float32",
+        )
+        text_model = CLIPTextModel._from_config(cfg)
+        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+    print("loading text encoder:", info)
+
+    return text_model, vae, unet
+
+
+def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
+    # only for reference
+    version_str = "sd"
+    if v2:
+        version_str += "_v2"
+    else:
+        version_str += "_v1"
+    if v_parameterization:
+        version_str += "_v"
+    return version_str
+
+
+def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
+    def convert_key(key):
+        # remove position_ids
+        if ".position_ids" in key:
+            return None
+
+        # common
+        key = key.replace("text_model.encoder.", "transformer.")
+        key = key.replace("text_model.", "")
+        if "layers" in key:
+            # resblocks conversion
+            key = key.replace(".layers.", ".resblocks.")
+            if ".layer_norm" in key:
+                key = key.replace(".layer_norm", ".ln_")
+            elif ".mlp." in key:
+                key = key.replace(".fc1.", ".c_fc.")
+                key = key.replace(".fc2.", ".c_proj.")
+            elif ".self_attn.out_proj" in key:
+                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+            elif ".self_attn." in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in DiffUsers model: {key}")
+        elif ".position_embedding" in key:
+            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+        elif ".token_embedding" in key:
+            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+        elif "final_layer_norm" in key:
+            key = key.replace("final_layer_norm", "ln_final")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert attention weights
+    for key in keys:
+        if "layers" in key and "q_proj" in key:
+            # concatenate q, k, v into a single tensor
+            key_q = key
+            key_k = key.replace("q_proj", "k_proj")
+            key_v = key.replace("q_proj", "v_proj")
+
+            value_q = checkpoint[key_q]
+            value_k = checkpoint[key_k]
+            value_v = checkpoint[key_v]
+            value = torch.cat([value_q, value_k, value_v])
+
+            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+            new_sd[new_key] = value
+
+    # optionally fabricate the final layer and other missing weights
+    if make_dummy_weights:
+        print("make dummy weights for resblock.23, text_projection and logit scale.")
+        keys = list(new_sd.keys())
+        for key in keys:
+            if key.startswith("transformer.resblocks.22."):
+                new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # must be cloned, otherwise saving with safetensors fails
+
+        # create weights that are not present in the Diffusers model
+        new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
+        new_sd["logit_scale"] = torch.tensor(1)
+
+    return new_sd
+
+
+def save_stable_diffusion_checkpoint(
+    v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
+):
+    if ckpt_path is not None:
+        # reference epoch/step from the checkpoint; also reload it (including the VAE), e.g. when the VAE is not in memory
+        checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+        if checkpoint is None:  # safetensors file, or a ckpt that is just a state_dict
+            checkpoint = {}
+            strict = False
+        else:
+            strict = True
+        if "state_dict" in state_dict:
+            del state_dict["state_dict"]
+    else:
+        # build a new checkpoint from scratch
+        assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
+        checkpoint = {}
+        state_dict = {}
+        strict = False
+
+    def update_sd(prefix, sd):
+        for k, v in sd.items():
+            key = prefix + k
+            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
+            if save_dtype is not None:
+                v = v.detach().clone().to("cpu").to(save_dtype)
+            state_dict[key] = v
+
+    # Convert the UNet model
+    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+    update_sd("model.diffusion_model.", unet_state_dict)
+
+    # Convert the text encoder model
+    if v2:
+        make_dummy = ckpt_path is None  # when there is no source checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
+        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+        update_sd("cond_stage_model.model.", text_enc_dict)
+    else:
+        text_enc_dict = text_encoder.state_dict()
+        update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+    # Convert the VAE
+    if vae is not None:
+        vae_dict = convert_vae_state_dict(vae.state_dict())
+        update_sd("first_stage_model.", vae_dict)
+
+    # Put together new checkpoint
+    key_count = len(state_dict.keys())
+    new_ckpt = {"state_dict": state_dict}
+
+    # epoch and global_step are sometimes not int
+    try:
+        if "epoch" in checkpoint:
+            epochs += checkpoint["epoch"]
+        if "global_step" in checkpoint:
+            steps += checkpoint["global_step"]
+    except:
+        pass
+
+    new_ckpt["epoch"] = epochs
+    new_ckpt["global_step"] = steps
+
+    if is_safetensors(output_file):
+        # TODO: should non-tensor values be removed from the dict?
+        save_file(state_dict, output_file, metadata)
+    else:
+        torch.save(new_ckpt, output_file)
+
+    return key_count
+
+
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
+    if pretrained_model_name_or_path is None:
+        # load default settings for v1/v2
+        if v2:
+            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
+        else:
+            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
+
+    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+    pipeline = StableDiffusionPipeline(
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=None,
+    )
+    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+    print(f"load VAE: {vae_id}")
+    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+        # Diffusers local/remote
+        try:
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+        except EnvironmentError as e:
+            print(f"exception occurs in loading vae: {e}")
+            print("retry with subfolder='vae'")
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+        return vae
+
+    # local
+    vae_config = create_vae_diffusers_config()
+
+    if vae_id.endswith(".bin"):
+        # SD 1.5 VAE on Huggingface
+        converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
+    else:
+        # StableDiffusion
+        vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
+        vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
+
+        # vae only or full model
+        full_model = False
+        for vae_key in vae_sd:
+            if vae_key.startswith(VAE_PREFIX):
+                full_model = True
+                break
+        if not full_model:
+            sd = {}
+            for key, value in vae_sd.items():
+                sd[VAE_PREFIX + key] = value
+            vae_sd = sd
+            del sd
+
+        # Convert the VAE model.
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
+
+# endregion
+
+
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+    max_width, max_height = max_reso
+    max_area = max_width * max_height
+
+    resos = set()
+
+    width = int(math.sqrt(max_area) // divisible) * divisible
+    resos.add((width, width))
+
+    width = min_size
+    while width <= max_size:
+        height = min(max_size, int((max_area // width) // divisible) * divisible)
+        if height >= min_size:
+            resos.add((width, height))
+            resos.add((height, width))
+
+        # # make additional resos
+        # if width >= height and width - divisible >= min_size:
+        #   resos.add((width - divisible, height))
+        #   resos.add((height, width - divisible))
+        # if height >= width and height - divisible >= min_size:
+        #   resos.add((width, height - divisible))
+        #   resos.add((height - divisible, width))
+
+        width += divisible
+
+    resos = list(resos)
+    resos.sort()
+    return resos
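+
+# For example, make_bucket_resolutions((512, 768)) returns (width, height) pairs that are
+# multiples of 64 with area at most 512*768, such as (576, 576), (256, 1024) and (1024, 256);
+# the __main__ block below prints the full list.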
+
+
+if __name__ == "__main__":
+    resos = make_bucket_resolutions((512, 768))
+    print(len(resos))
+    print(resos)
+    aspect_ratios = [w / h for w, h in resos]
+    print(aspect_ratios)
+
+    ars = set()
+    for ar in aspect_ratios:
+        if ar in ars:
+            print("error! duplicate ar:", ar)
+        ars.add(ar)
diff --git a/external/llite/library/original_unet.py b/external/llite/library/original_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..00997e7c01eb8b023140330284ef32317f975e4c
--- /dev/null
+++ b/external/llite/library/original_unet.py
@@ -0,0 +1,1915 @@
+# Brings over from Diffusers 0.10.2 only the parts needed for Stable Diffusion.
+# Parts made unnecessary by conditional branching have been removed.
+# Most of the code is copied from Diffusers.
+# As a constraint, the model's state_dict must be in the same format as that of Diffusers 0.10.2.
+
+"""
+Differences between v1.5 and v2.1:
+- attention_head_dim is an int (1.5) or a list[int] (2.1)
+- cross_attention_dim is 768 (1.5) or 1024 (2.1)
+- use_linear_projection is absent (=False, 1.5) or present and true (2.1)
+- upcast_attention is False (1.5) or True (2.1)
+- (the following can probably be ignored)
+- sample_size is 64 (1.5) or 96 (2.1)
+- dual_cross_attention present or absent
+- num_class_embeds present or absent
+- only_cross_attention present or absent
+
+v1.5
+{
+  "_class_name": "UNet2DConditionModel",
+  "_diffusers_version": "0.6.0",
+  "act_fn": "silu",
+  "attention_head_dim": 8,
+  "block_out_channels": [
+    320,
+    640,
+    1280,
+    1280
+  ],
+  "center_input_sample": false,
+  "cross_attention_dim": 768,
+  "down_block_types": [
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "DownBlock2D"
+  ],
+  "downsample_padding": 1,
+  "flip_sin_to_cos": true,
+  "freq_shift": 0,
+  "in_channels": 4,
+  "layers_per_block": 2,
+  "mid_block_scale_factor": 1,
+  "norm_eps": 1e-05,
+  "norm_num_groups": 32,
+  "out_channels": 4,
+  "sample_size": 64,
+  "up_block_types": [
+    "UpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D"
+  ]
+}
+
+v2.1
+{
+  "_class_name": "UNet2DConditionModel",
+  "_diffusers_version": "0.10.0.dev0",
+  "act_fn": "silu",
+  "attention_head_dim": [
+    5,
+    10,
+    20,
+    20
+  ],
+  "block_out_channels": [
+    320,
+    640,
+    1280,
+    1280
+  ],
+  "center_input_sample": false,
+  "cross_attention_dim": 1024,
+  "down_block_types": [
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "DownBlock2D"
+  ],
+  "downsample_padding": 1,
+  "dual_cross_attention": false,
+  "flip_sin_to_cos": true,
+  "freq_shift": 0,
+  "in_channels": 4,
+  "layers_per_block": 2,
+  "mid_block_scale_factor": 1,
+  "norm_eps": 1e-05,
+  "norm_num_groups": 32,
+  "num_class_embeds": null,
+  "only_cross_attention": false,
+  "out_channels": 4,
+  "sample_size": 96,
+  "up_block_types": [
+    "UpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D"
+  ],
+  "use_linear_projection": true,
+  "upcast_attention": true
+}
+"""
+
+import math
+from types import SimpleNamespace
+from typing import Dict, Optional, Tuple, Union
+import torch
+from torch import nn
+from torch.nn import functional as F
+from einops import rearrange
+
+BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
+TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
+TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
+IN_CHANNELS: int = 4
+OUT_CHANNELS: int = 4
+LAYERS_PER_BLOCK: int = 2
+LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
+TIME_EMBED_FLIP_SIN_TO_COS: bool = True
+TIME_EMBED_FREQ_SHIFT: int = 0
+NORM_GROUPS: int = 32
+NORM_EPS: float = 1e-5
+TRANSFORMER_NORM_NUM_GROUPS = 32
+
+DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
+UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
+
+
+# region memory efficient attention
+
+# CrossAttention that uses FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+
+class FlashAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """Algorithm 2 in the paper"""
+
+        device = q.device
+        dtype = q.dtype
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
+
+        scale = q.shape[-1] ** -0.5
+
+        if not exists(mask):
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, "b n -> b 1 1 n")
+            mask = mask.split(q_bucket_size, dim=-1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim=-2),
+            all_row_maxes.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if exists(row_mask):
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if exists(row_mask):
+                    exp_weights.masked_fill_(~row_mask, 0.0)
+
+                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
+
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+        return o
+
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """Algorithm 4 in the paper"""
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, l, m = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            do.split(q_bucket_size, dim=-2),
+            mask,
+            l.split(q_bucket_size, dim=-2),
+            m.split(q_bucket_size, dim=-2),
+            dq.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+                dk.split(k_bucket_size, dim=-2),
+                dv.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                exp_attn_weights = torch.exp(attn_weights - mc)
+
+                if exists(row_mask):
+                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                p = exp_attn_weights / lc
+
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
+                D = (doc * oc).sum(dim=-1, keepdims=True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None
+
+
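+# Call-shape sketch for FlashAttentionFunction (for reference, not executed here):
+#   q, k, v: (batch, heads, seq_len, dim_head); mask: (batch, key_len) or None
+#   out = FlashAttentionFunction.apply(q, k, v, mask, causal, q_bucket_size, k_bucket_size)
+# This matches how forward_memory_efficient_mem_eff() below invokes it with
+# q_bucket_size=512 and k_bucket_size=1024.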
+# endregion
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+    return next(parameter.parameters()).dtype
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+    return next(parameter.parameters()).device
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+    embeddings. :return: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
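+# Example shape (for reference): get_timestep_embedding(torch.tensor([0, 999]), 320) returns a
+# (2, 320) float32 tensor of concatenated sin/cos features; with flip_sin_to_cos=True (as used by
+# Timesteps below) the cosine half comes first.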
+
+# Deep Shrink: this helper is duplicated here rather than shared, to minimize dependencies.
+def resize_like(x, target, mode="bicubic", align_corners=False):
+    org_dtype = x.dtype
+    if org_dtype == torch.bfloat16:
+        x = x.to(torch.float32)
+
+    if x.shape[-2:] != target.shape[-2:]:
+        if mode == "nearest":
+            x = F.interpolate(x, size=target.shape[-2:], mode=mode)
+        else:
+            x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
+
+    if org_dtype == torch.bfloat16:
+        x = x.to(org_dtype)
+    return x
+
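+# Example (for reference): resize_like(x, target) with x of shape (B, C, 32, 32) and a target with
+# spatial size 64x64 returns x bicubically resized to (B, C, 64, 64); bfloat16 inputs are cast to
+# float32 around F.interpolate and cast back afterwards, as implemented above.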
+
+class SampleOutput:
+    def __init__(self, sample):
+        self.sample = sample
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.act = None
+        if act_fn == "silu":
+            self.act = nn.SiLU()
+        elif act_fn == "mish":
+            self.act = nn.Mish()
+
+        if out_dim is not None:
+            time_embed_dim_out = out_dim
+        else:
+            time_embed_dim_out = time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+    def forward(self, sample):
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+        return sample
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+        )
+        return t_emb
+
+
+class ResnetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
+
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
+
+        self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        # if non_linearity == "swish":
+        self.nonlinearity = lambda x: F.silu(x)
+
+        self.use_in_shortcut = self.in_channels != self.out_channels
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, input_tensor, temb):
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.conv1(hidden_states)
+
+        temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+        hidden_states = hidden_states + temb
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+
+        output_tensor = input_tensor + hidden_states
+
+        return output_tensor
+
+
+class DownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        add_downsample=True,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = False
+        resnets = []
+
+        for i in range(LAYERS_PER_BLOCK):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)])
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        pass
+
+    def set_use_sdpa(self, sdpa):
+        pass
+
+    def forward(self, hidden_states, temb=None):
+        output_states = ()
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
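+# Note: create_custom_forward above wraps a module in a plain callable over tensors so that
+# torch.utils.checkpoint.checkpoint can re-run it during backward; the cross-attention blocks later
+# in this file use the same pattern with return_dict=False so the checkpointed call returns tensors.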
+
+class Downsample2D(nn.Module):
+    def __init__(self, channels, out_channels):
+        super().__init__()
+
+        self.channels = channels
+        self.out_channels = out_channels
+
+        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
+
+    def forward(self, hidden_states):
+        assert hidden_states.shape[1] == self.channels
+        hidden_states = self.conv(hidden_states)
+
+        return hidden_states
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        inner_dim = dim_head * heads
+        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        # no dropout here
+
+        self.use_memory_efficient_attention_xformers = False
+        self.use_memory_efficient_attention_mem_eff = False
+        self.use_sdpa = False
+
+        # Attention processor
+        self.processor = None
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_memory_efficient_attention_mem_eff = mem_eff
+
+    def set_use_sdpa(self, sdpa):
+        self.use_sdpa = sdpa
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def set_processor(self, processor):
+        self.processor = processor
+
+    def get_processor(self):
+        return self.processor
+
+    def forward(self, hidden_states, context=None, mask=None, **kwargs):
+        if self.processor is not None:
+            (
+                hidden_states,
+                encoder_hidden_states,
+                attention_mask,
+            ) = translate_attention_names_from_diffusers(
+                hidden_states=hidden_states, context=context, mask=mask, **kwargs
+            )
+            return self.processor(
+                attn=self,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                **kwargs
+            )
+        if self.use_memory_efficient_attention_xformers:
+            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
+        if self.use_memory_efficient_attention_mem_eff:
+            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
+        if self.use_sdpa:
+            return self.forward_sdpa(hidden_states, context, mask)
+
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
+
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
+
+        hidden_states = self._attention(query, key, value)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # hidden_states = self.to_out[1](hidden_states)     # no dropout
+        return hidden_states
+
+    def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    # TODO support Hypernetworks
+    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
+        import xformers.ops
+
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel automatically
+
+        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+
+    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
+        flash_func = FlashAttentionFunction
+
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        h = self.heads
+        q = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        out = self.to_out[0](out)
+        return out
+
+    def forward_sdpa(self, x, context=None, mask=None):
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+        out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+
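+# Backend selection sketch (for reference): CrossAttention.forward above picks, in order,
+#   self.processor (set via set_processor)           -> custom attention processor
+#   set_use_memory_efficient_attention(True, ...)    -> xformers memory-efficient attention
+#   set_use_memory_efficient_attention(..., True)    -> FlashAttentionFunction (mem_eff)
+#   set_use_sdpa(True)                               -> F.scaled_dot_product_attention
+# and otherwise falls back to the plain baddbmm/softmax path in _attention().
+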
+def translate_attention_names_from_diffusers(
+    hidden_states: torch.FloatTensor,
+    context: Optional[torch.FloatTensor] = None,
+    mask: Optional[torch.FloatTensor] = None,
+    # HF naming
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None
+):
+    # translate from hugging face diffusers
+    context = context if context is not None else encoder_hidden_states
+
+    # translate from hugging face diffusers
+    mask = mask if mask is not None else attention_mask
+
+    return hidden_states, context, mask
+
+# feedforward
+class GEGLU(nn.Module):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * self.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+    ):
+        super().__init__()
+        inner_dim = int(dim * 4)  # mult is always 4
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Identity())  # placeholder for nn.Dropout(0.0); identical behavior with no dropout
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim))
+
+    def forward(self, hidden_states):
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
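+# Shape note (for reference): FeedForward(320) maps (B, N, 320) through GEGLU (Linear to 2 * 1280,
+# gated GELU) to (B, N, 1280), then back to (B, N, 320); the dropout slot is a no-op nn.Identity.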
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
+    ):
+        super().__init__()
+
+        # 1. Self-Attn
+        self.attn1 = CrossAttention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            upcast_attention=upcast_attention,
+        )
+        self.ff = FeedForward(dim)
+
+        # 2. Cross-Attn
+        self.attn2 = CrossAttention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            upcast_attention=upcast_attention,
+        )
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim)
+
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
+        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
+        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa: bool):
+        self.attn1.set_use_sdpa(sdpa)
+        self.attn2.set_use_sdpa(sdpa)
+
+    def forward(self, hidden_states, context=None, timestep=None):
+        # 1. Self-Attention
+        norm_hidden_states = self.norm1(hidden_states)
+
+        hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+        # 2. Cross-Attention
+        norm_hidden_states = self.norm2(hidden_states)
+        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+
+        # 3. Feed-forward
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+        return hidden_states
+
+
+class Transformer2DModel(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        cross_attention_dim: Optional[int] = None,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        self.use_linear_projection = use_linear_projection
+
+        self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
+
+        if use_linear_projection:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    upcast_attention=upcast_attention,
+                )
+            ]
+        )
+
+        if use_linear_projection:
+            self.proj_out = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_sdpa(sdpa)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+        # 1. Input
+        batch, _, height, width = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+
+        # 3. Output
+        if not self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        if not return_dict:
+            return (output,)
+
+        return SampleOutput(sample=output)
+
+
+class CrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        add_downsample=True,
+        cross_attention_dim=1280,
+        attn_num_head_channels=1,
+        use_linear_projection=False,
+        upcast_attention=False,
+    ):
+        super().__init__()
+        self.has_cross_attention = True
+        resnets = []
+        attentions = []
+
+        self.attn_num_head_channels = attn_num_head_channels
+
+        for i in range(LAYERS_PER_BLOCK):
+            in_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
+            attentions.append(
+                Transformer2DModel(
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    in_channels=out_channels,
+                    cross_attention_dim=cross_attention_dim,
+                    use_linear_projection=use_linear_projection,
+                    upcast_attention=upcast_attention,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        use_linear_projection=False,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.attn_num_head_channels = attn_num_head_channels
+
+        # Middle block has two resnets and one attention
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+            ),
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+            ),
+        ]
+        attentions = [
+            Transformer2DModel(
+                attn_num_head_channels,
+                in_channels // attn_num_head_channels,
+                in_channels=in_channels,
+                cross_attention_dim=cross_attention_dim,
+                use_linear_projection=use_linear_projection,
+            )
+        ]
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        for i, resnet in enumerate(self.resnets):
+            attn = None if i == 0 else self.attentions[i - 1]
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                if attn is not None:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                    )[0]
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, encoder_hidden_states).sample
+                hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
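+# Note: the loop above runs resnets[0] first (attn is None when i == 0), then attentions[0]
+# followed by resnets[1], i.e. the usual Diffusers mid-block order resnet -> cross-attention -> resnet.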
+
+class Upsample2D(nn.Module):
+    def __init__(self, channels, out_channels):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels
+        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, hidden_states, output_size):
+        assert hidden_states.shape[1] == self.channels
+
+        # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+        # https://github.com/pytorch/pytorch/issues/86679
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.float32)
+
+        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+        if hidden_states.shape[0] >= 64:
+            hidden_states = hidden_states.contiguous()
+
+        # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+        # If the input is bfloat16, we cast back to bfloat16
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(dtype)
+
+        hidden_states = self.conv(hidden_states)
+
+        return hidden_states
+
+
+class UpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        add_upsample=True,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = False
+        resnets = []
+
+        for i in range(LAYERS_PER_BLOCK_UP):
+            res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        pass
+
+    def set_use_sdpa(self, sdpa):
+        pass
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        add_upsample=True,
+        use_linear_projection=False,
+        upcast_attention=False,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.attn_num_head_channels = attn_num_head_channels
+
+        for i in range(LAYERS_PER_BLOCK_UP):
+            res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                )
+            )
+            attentions.append(
+                Transformer2DModel(
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    in_channels=out_channels,
+                    cross_attention_dim=cross_attention_dim,
+                    use_linear_projection=use_linear_projection,
+                    upcast_attention=upcast_attention,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
+    def forward(
+        self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+        upsample_size=None,
+    ):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+def get_down_block(
+    down_block_type,
+    in_channels,
+    out_channels,
+    add_downsample,
+    attn_num_head_channels,
+    cross_attention_dim,
+    use_linear_projection,
+    upcast_attention,
+):
+    if down_block_type == "DownBlock2D":
+        return DownBlock2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_downsample=add_downsample,
+        )
+    elif down_block_type == "CrossAttnDownBlock2D":
+        return CrossAttnDownBlock2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_downsample=add_downsample,
+            cross_attention_dim=cross_attention_dim,
+            attn_num_head_channels=attn_num_head_channels,
+            use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
+        )
+
+
+def get_up_block(
+    up_block_type,
+    in_channels,
+    out_channels,
+    prev_output_channel,
+    add_upsample,
+    attn_num_head_channels,
+    cross_attention_dim=None,
+    use_linear_projection=False,
+    upcast_attention=False,
+):
+    if up_block_type == "UpBlock2D":
+        return UpBlock2D(
+            in_channels=in_channels,
+            prev_output_channel=prev_output_channel,
+            out_channels=out_channels,
+            add_upsample=add_upsample,
+        )
+    elif up_block_type == "CrossAttnUpBlock2D":
+        return CrossAttnUpBlock2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            attn_num_head_channels=attn_num_head_channels,
+            cross_attention_dim=cross_attention_dim,
+            add_upsample=add_upsample,
+            use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
+        )
+
+
+class UNet2DConditionModel(nn.Module):
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        cross_attention_dim: int = 1280,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        assert sample_size is not None, "sample_size must be specified"
+        print(
+            f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
+        )
+
+        # defined here so they can be referenced from outside
+        self.in_channels = IN_CHANNELS
+        self.out_channels = OUT_CHANNELS
+
+        self.sample_size = sample_size
+        self.prepare_config(sample_size=sample_size)
+
+        # the way the modules are held cannot be changed, because that would change the state_dict format
+
+        # input
+        self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
+
+        # time
+        self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
+
+        self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
+
+        self.down_blocks = nn.ModuleList([])
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * 4
+
+        # down
+        output_channel = BLOCK_OUT_CHANNELS[0]
+        for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
+            input_channel = output_channel
+            output_channel = BLOCK_OUT_CHANNELS[i]
+            is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                add_downsample=not is_final_block,
+                attn_num_head_channels=attention_head_dim[i],
+                cross_attention_dim=cross_attention_dim,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock2DCrossAttn(
+            in_channels=BLOCK_OUT_CHANNELS[-1],
+            attn_num_head_channels=attention_head_dim[-1],
+            cross_attention_dim=cross_attention_dim,
+            use_linear_projection=use_linear_projection,
+        )
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(UP_BLOCK_TYPES):
+            is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                add_upsample=add_upsample,
+                attn_num_head_channels=reversed_attention_head_dim[i],
+                cross_attention_dim=cross_attention_dim,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
+
+    # region diffusers compatibility
+    def prepare_config(self, *args, **kwargs):
+        self.config = SimpleNamespace(**kwargs)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        return get_parameter_dtype(self)
+
+    @property
+    def device(self) -> torch.device:
+        # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
+        return get_parameter_device(self)
+
+    def set_attention_slice(self, slice_size):
+        raise NotImplementedError("Attention slicing is not supported for this model.")
+
+    def is_gradient_checkpointing(self) -> bool:
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        self.set_gradient_checkpointing(value=True)
+
+    def disable_gradient_checkpointing(self):
+        self.set_gradient_checkpointing(value=False)
+
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            module.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa: bool) -> None:
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            module.set_use_sdpa(sdpa)
+
+    def set_gradient_checkpointing(self, value=False):
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
+            module.gradient_checkpointing = value
+
+    # endregion
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+    ) -> Union[SampleOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a dict instead of a plain tuple.
+
+        Returns:
+            `SampleOutput` or `tuple`:
+            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        # In pixel terms this means image sizes should be multiples of 64. Other sizes are supported
+        # by adjusting the upsample size when needed, but quality may suffer, so keeping dimensions
+        # divisible by 64 is recommended.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        # when the size is not divisible by 64, pass the target size to the upsamplers
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            # logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # 1. time
+        timesteps = timestep
+        timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only needed when timesteps arrives in an unusual form
+
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        # (the cast could probably just be done inside time_proj instead)
+        t_emb = t_emb.to(dtype=self.dtype)
+        emb = self.time_embedding(t_emb)
+
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            # the down blocks could be made to always take encoder_hidden_states in forward,
+            # but branching here is probably easier to follow
+            if downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # add the ControlNet outputs to the skip connections
+        if down_block_additional_residuals is not None:
+            down_block_res_samples = list(down_block_res_samples)
+            for i in range(len(down_block_res_samples)):
+                down_block_res_samples[i] += down_block_additional_residuals[i]
+            down_block_res_samples = tuple(down_block_res_samples)
+
+        # 4. mid
+        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # add the ControlNet output
+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection
+
+            # if we have not reached the final block and need to forward the upsample size, we do it here
+            # (as noted above, upsample_size is passed to every block except the final one)
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return SampleOutput(sample=sample)
+
+    def handle_unusual_timesteps(self, sample, timesteps):
+        r"""
+        Convert timesteps to a Tensor if it is not one already, and broadcast it to the batch size
+        so that it is compatible with ONNX/Core ML.
+        """
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        return timesteps
+
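+# Illustrative forward call (comment only; `unet` is an instance of the UNet2DConditionModel above,
+# and the tensor shapes are assumptions for an SD1.x-sized model):
+#   noise_pred = unet(
+#       sample=torch.randn(1, 4, 64, 64),              # (batch, channel, height, width)
+#       timestep=torch.tensor([999]),                   # (batch,) timesteps
+#       encoder_hidden_states=torch.randn(1, 77, 768),  # (batch, sequence_length, feature_dim)
+#   ).sample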
+
+class InferUNet2DConditionModel:
+    def __init__(self, original_unet: UNet2DConditionModel):
+        self.delegate = original_unet
+
+        # override original model's forward method: because forward is not called by `__call__`
+        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
+        self.delegate.forward = self.forward
+
+        # override original model's up blocks' forward method
+        for up_block in self.delegate.up_blocks:
+            if up_block.__class__.__name__ == "UpBlock2D":
+
+                def resnet_wrapper(func, block):
+                    def forward(*args, **kwargs):
+                        return func(block, *args, **kwargs)
+
+                    return forward
+
+                up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
+
+            elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+
+                def cross_attn_up_wrapper(func, block):
+                    def forward(*args, **kwargs):
+                        return func(block, *args, **kwargs)
+
+                    return forward
+
+                up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
+
+        # Deep Shrink
+        self.ds_depth_1 = None
+        self.ds_depth_2 = None
+        self.ds_timesteps_1 = None
+        self.ds_timesteps_2 = None
+        self.ds_ratio = None
+
+    # call original model's methods
+    def __getattr__(self, name):
+        return getattr(self.delegate, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.delegate(*args, **kwargs)
+
+    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
+        if ds_depth_1 is None:
+            print("Deep Shrink is disabled.")
+            self.ds_depth_1 = None
+            self.ds_timesteps_1 = None
+            self.ds_depth_2 = None
+            self.ds_timesteps_2 = None
+            self.ds_ratio = None
+        else:
+            print(
+                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
+            )
+            self.ds_depth_1 = ds_depth_1
+            self.ds_timesteps_1 = ds_timesteps_1
+            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
+            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
+            self.ds_ratio = ds_ratio
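+    # Illustrative configuration (values are placeholders): shrink the latent to half size at
+    # down-block depth 3 while timestep >= 650, then at depth 2 while 400 <= timestep < 650:
+    #   unet = InferUNet2DConditionModel(original_unet)
+    #   unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_depth_2=2, ds_timesteps_2=400, ds_ratio=0.5)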
+
+    def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+        for resnet in _self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # Deep Shrink
+            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
+                hidden_states = resize_like(hidden_states, res_hidden_states)
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb)
+
+        if _self.upsamplers is not None:
+            for upsampler in _self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+    def cross_attn_up_block_forward(
+        self,
+        _self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+        upsample_size=None,
+    ):
+        for resnet, attn in zip(_self.resnets, _self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # Deep Shrink
+            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
+                hidden_states = resize_like(hidden_states, res_hidden_states)
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+        if _self.upsamplers is not None:
+            for upsampler in _self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+    ) -> Union[Dict, Tuple]:
+        r"""
+        The current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink added.
+
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a dict instead of a plain tuple.
+
+        Returns:
+            `SampleOutput` or `tuple`:
+            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+
+        _self = self.delegate
+
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        # In other words, the sample size normally has to be a multiple of 2^(number of upsamplers), i.e. 64.
+        # To support other sizes, the upsample size is adjusted when necessary; quality likely degrades,
+        # so keeping sizes divisible by 64 is recommended.
+        default_overall_up_factor = 2**_self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        # (when the size is not divisible by 64, pass the size to the upsamplers)
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            # logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # 1. time
+        timesteps = timestep
+        timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # handle non-Tensor or 0-dim timesteps
+
+        t_emb = _self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this
+        # (casting inside time_proj might be a cleaner place to do it).
+        t_emb = t_emb.to(dtype=_self.dtype)
+        emb = _self.time_embedding(t_emb)
+
+        # 2. pre-process
+        sample = _self.conv_in(sample)
+
+        down_block_res_samples = (sample,)
+        for depth, downsample_block in enumerate(_self.down_blocks):
+            # Deep Shrink
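+            # (downscale the latent at the configured depth while the timestep is still above the
+            #  configured threshold; timesteps are assumed identical across the batch, hence timesteps[0])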
+            if self.ds_depth_1 is not None:
+                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
+                    self.ds_depth_2 is not None
+                    and depth == self.ds_depth_2
+                    and timesteps[0] < self.ds_timesteps_1
+                    and timesteps[0] >= self.ds_timesteps_2
+                ):
+                    org_dtype = sample.dtype
+                    if org_dtype == torch.bfloat16:
+                        sample = sample.to(torch.float32)
+                    sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
+
+            # the down blocks could be made to always take encoder_hidden_states in forward,
+            # but this explicit branching may be easier to follow
+            if downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # add the ControlNet outputs to the skip connections
+        if down_block_additional_residuals is not None:
+            down_block_res_samples = list(down_block_res_samples)
+            for i in range(len(down_block_res_samples)):
+                down_block_res_samples[i] += down_block_additional_residuals[i]
+            down_block_res_samples = tuple(down_block_res_samples)
+
+        # 4. mid
+        sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # add the ControlNet output
+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(_self.up_blocks):
+            is_final_block = i == len(_self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection
+
+            # if we have not reached the final block and need to forward the upsample size, we do it here
+            # (as noted above, upsample_size is passed to every block except the final one)
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        sample = _self.conv_norm_out(sample)
+        sample = _self.conv_act(sample)
+        sample = _self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return SampleOutput(sample=sample)
diff --git a/external/llite/library/sai_model_spec.py b/external/llite/library/sai_model_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..472686ba404b70664e0db007d9330101c7b2d1dc
--- /dev/null
+++ b/external/llite/library/sai_model_spec.py
@@ -0,0 +1,305 @@
+# based on https://github.com/Stability-AI/ModelSpec
+import datetime
+import hashlib
+from io import BytesIO
+import os
+from typing import List, Optional, Tuple, Union
+import safetensors
+
+r"""
+# Metadata Example
+metadata = {
+    # === Must ===
+    "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
+    "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
+    "modelspec.implementation": "sgm",
+    "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
+    # === Should ===
+    "modelspec.author": "Example Corp", # Your name or company name
+    "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
+    "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
+    # === Can ===
+    "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
+    "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
+}
+"""
+
+BASE_METADATA = {
+    # === Must ===
+    "modelspec.sai_model_spec": "1.0.0",  # Required version ID for the spec
+    "modelspec.architecture": None,
+    "modelspec.implementation": None,
+    "modelspec.title": None,
+    "modelspec.resolution": None,
+    # === Should ===
+    "modelspec.description": None,
+    "modelspec.author": None,
+    "modelspec.date": None,
+    # === Can ===
+    "modelspec.license": None,
+    "modelspec.tags": None,
+    "modelspec.merged_from": None,
+    "modelspec.prediction_type": None,
+    "modelspec.timestep_range": None,
+    "modelspec.encoder_layer": None,
+}
+
+# define constants only for the keys that are referenced elsewhere in the code
+MODELSPEC_TITLE = "modelspec.title"
+
+ARCH_SD_V1 = "stable-diffusion-v1"
+ARCH_SD_V2_512 = "stable-diffusion-v2-512"
+ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
+ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
+
+ADAPTER_LORA = "lora"
+ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
+
+IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
+IMPL_DIFFUSERS = "diffusers"
+
+PRED_TYPE_EPSILON = "epsilon"
+PRED_TYPE_V = "v"
+
+
+def load_bytes_in_safetensors(tensors):
+    bytes = safetensors.torch.save(tensors)
+    b = BytesIO(bytes)
+
+    # safetensors layout: 8-byte little-endian header length, JSON header, then the raw tensor bytes
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    # skip the header and return only the tensor bytes
+    offset = n + 8
+    b.seek(offset)
+
+    return b.read()
+
+
+def precalculate_safetensors_hashes(state_dict):
+    # calculate each tensor one by one to reduce memory usage
+    hash_sha256 = hashlib.sha256()
+    for tensor in state_dict.values():
+        single_tensor_sd = {"tensor": tensor}
+        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
+        hash_sha256.update(bytes_for_tensor)
+
+    return f"0x{hash_sha256.hexdigest()}"
+
+
+def update_hash_sha256(metadata: dict, state_dict: dict):
+    raise NotImplementedError
+
+
+def build_metadata(
+    state_dict: Optional[dict],
+    v2: bool,
+    v_parameterization: bool,
+    sdxl: bool,
+    lora: bool,
+    textual_inversion: bool,
+    timestamp: float,
+    title: Optional[str] = None,
+    reso: Optional[Union[int, Tuple[int, int]]] = None,
+    is_stable_diffusion_ckpt: Optional[bool] = None,
+    author: Optional[str] = None,
+    description: Optional[str] = None,
+    license: Optional[str] = None,
+    tags: Optional[str] = None,
+    merged_from: Optional[str] = None,
+    timesteps: Optional[Tuple[int, int]] = None,
+    clip_skip: Optional[int] = None,
+):
+    # if state_dict is None, hash is not calculated
+
+    metadata = {}
+    metadata.update(BASE_METADATA)
+
+    # TODO: implement this once a hash computation that is both correct and memory-efficient is found
+    # if state_dict is not None:
+    # hash = precalculate_safetensors_hashes(state_dict)
+    # metadata["modelspec.hash_sha256"] = hash
+
+    if sdxl:
+        arch = ARCH_SD_XL_V1_BASE
+    elif v2:
+        if v_parameterization:
+            arch = ARCH_SD_V2_768_V
+        else:
+            arch = ARCH_SD_V2_512
+    else:
+        arch = ARCH_SD_V1
+
+    if lora:
+        arch += f"/{ADAPTER_LORA}"
+    elif textual_inversion:
+        arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
+
+    metadata["modelspec.architecture"] = arch
+
+    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
+        is_stable_diffusion_ckpt = True  # default is stable diffusion ckpt if not lora and not textual_inversion
+
+    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
+        # Stable Diffusion ckpt, TI, SDXL LoRA
+        impl = IMPL_STABILITY_AI
+    else:
+        # v1/v2 LoRA or Diffusers
+        impl = IMPL_DIFFUSERS
+    metadata["modelspec.implementation"] = impl
+
+    if title is None:
+        if lora:
+            title = "LoRA"
+        elif textual_inversion:
+            title = "TextualInversion"
+        else:
+            title = "Checkpoint"
+        title += f"@{timestamp}"
+    metadata[MODELSPEC_TITLE] = title
+
+    if author is not None:
+        metadata["modelspec.author"] = author
+    else:
+        del metadata["modelspec.author"]
+
+    if description is not None:
+        metadata["modelspec.description"] = description
+    else:
+        del metadata["modelspec.description"]
+
+    if merged_from is not None:
+        metadata["modelspec.merged_from"] = merged_from
+    else:
+        del metadata["modelspec.merged_from"]
+
+    if license is not None:
+        metadata["modelspec.license"] = license
+    else:
+        del metadata["modelspec.license"]
+
+    if tags is not None:
+        metadata["modelspec.tags"] = tags
+    else:
+        del metadata["modelspec.tags"]
+
+    # remove microsecond from time
+    int_ts = int(timestamp)
+
+    # time to iso-8601 compliant date
+    date = datetime.datetime.fromtimestamp(int_ts).isoformat()
+    metadata["modelspec.date"] = date
+
+    if reso is not None:
+        # comma separated to tuple
+        if isinstance(reso, str):
+            reso = tuple(map(int, reso.split(",")))
+        if len(reso) == 1:
+            reso = (reso[0], reso[0])
+    else:
+        # resolution is defined in dataset, so use default
+        if sdxl:
+            reso = 1024
+        elif v2 and v_parameterization:
+            reso = 768
+        else:
+            reso = 512
+    if isinstance(reso, int):
+        reso = (reso, reso)
+
+    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
+
+    if v_parameterization:
+        metadata["modelspec.prediction_type"] = PRED_TYPE_V
+    else:
+        metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
+
+    if timesteps is not None:
+        if isinstance(timesteps, str) or isinstance(timesteps, int):
+            timesteps = (timesteps, timesteps)
+        if len(timesteps) == 1:
+            timesteps = (timesteps[0], timesteps[0])
+        metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
+    else:
+        del metadata["modelspec.timestep_range"]
+
+    if clip_skip is not None:
+        metadata["modelspec.encoder_layer"] = f"{clip_skip}"
+    else:
+        del metadata["modelspec.encoder_layer"]
+
+    # # assert all values are filled
+    # assert all([v is not None for v in metadata.values()]), metadata
+    if not all([v is not None for v in metadata.values()]):
+        print(f"Internal error: some metadata values are None: {metadata}")
+
+    return metadata
+
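+# Illustrative usage (argument values are placeholders):
+#   import time
+#   md = build_metadata(None, v2=False, v_parameterization=False, sdxl=True, lora=True,
+#                       textual_inversion=False, timestamp=time.time(), title="my-lora", reso=1024)
+#   md["modelspec.architecture"]  # -> "stable-diffusion-xl-v1-base/lora"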
+
+# region utils
+
+
+def get_title(metadata: dict) -> Optional[str]:
+    return metadata.get(MODELSPEC_TITLE, None)
+
+
+def load_metadata_from_safetensors(model: str) -> dict:
+    if not model.endswith(".safetensors"):
+        return {}
+
+    with safetensors.safe_open(model, framework="pt") as f:
+        metadata = f.metadata()
+    if metadata is None:
+        metadata = {}
+    return metadata
+
+
+def build_merged_from(models: List[str]) -> str:
+    def get_title(model: str):
+        metadata = load_metadata_from_safetensors(model)
+        title = metadata.get(MODELSPEC_TITLE, None)
+        if title is None:
+            title = os.path.splitext(os.path.basename(model))[0]  # use filename
+        return title
+
+    titles = [get_title(model) for model in models]
+    return ", ".join(titles)
+
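+# e.g. build_merged_from(["model_a.safetensors", "model_b.safetensors"]) joins the modelspec.title of
+# each file, falling back to the file name without extension: "model_a, model_b".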
+
+# endregion
+
+
+r"""
+if __name__ == "__main__":
+    import argparse
+    import json
+    import struct
+    import time
+    import torch
+    from safetensors.torch import load_file
+    from library import train_util
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ckpt", type=str, required=True)
+    args = parser.parse_args()
+
+    print(f"Loading {args.ckpt}")
+    state_dict = load_file(args.ckpt)
+
+    print(f"Calculating metadata")
+    metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
+    print(metadata)
+    del state_dict
+
+    # by reference implementation
+    with open(args.ckpt, mode="rb") as file_data:
+        file_hash = hashlib.sha256()
+        head_len = struct.unpack("Q", file_data.read(8))  # int64 header length prefix
+        header = json.loads(file_data.read(head_len[0]))  # header itself, json string
+        content = (
+            file_data.read()
+        )  # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
+        file_hash.update(content)
+        # ===== Update the hash for modelspec =====
+        by_ref = f"0x{file_hash.hexdigest()}"
+    print(by_ref)
+    print("is same?", by_ref == metadata["modelspec.hash_sha256"])
+
+"""
diff --git a/external/llite/library/sdxl_lpw_stable_diffusion.py b/external/llite/library/sdxl_lpw_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..974adf716a84b3a5c2752b6904d63b16f6922ef9
--- /dev/null
+++ b/external/llite/library/sdxl_lpw_stable_diffusion.py
@@ -0,0 +1,1342 @@
+# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
+# and modified to support SDXL
+
+import inspect
+import re
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from tqdm import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import SchedulerMixin, StableDiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.utils import logging
+from PIL import Image
+
+from external.llite.library import sdxl_model_util, sdxl_train_util, train_util
+
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+re_attention = re.compile(
+    r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+    re.X,
+)
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+    Accepted tokens are:
+      (abc) - increases attention to abc by a multiplier of 1.1
+      (abc:3.12) - increases attention to abc by a multiplier of 3.12
+      [abc] - decreases attention to abc by a multiplier of 1.1
+      \( - literal character '('
+      \[ - literal character '['
+      \) - literal character ')'
+      \] - literal character ']'
+      \\ - literal character '\'
+      anything else - just text
+    >>> parse_prompt_attention('normal text')
+    [['normal text', 1.0]]
+    >>> parse_prompt_attention('an (important) word')
+    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+    >>> parse_prompt_attention('(unbalanced')
+    [['unbalanced', 1.1]]
+    >>> parse_prompt_attention('\(literal\]')
+    [['(literal]', 1.0]]
+    >>> parse_prompt_attention('(unnecessary)(parens)')
+    [['unnecessaryparens', 1.1]]
+    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+    [['a ', 1.0],
+     ['house', 1.5730000000000004],
+     [' ', 1.1],
+     ['on', 1.0],
+     [' a ', 1.1],
+     ['hill', 0.55],
+     [', sun, ', 1.1],
+     ['sky', 1.4641000000000006],
+     ['.', 1.1]]
+    """
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith("\\"):
+            res.append([text[1:], 1.0])
+        elif text == "(":
+            round_brackets.append(len(res))
+        elif text == "[":
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ")" and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == "]" and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            res.append([text, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    if len(res) == 0:
+        res = [["", 1.0]]
+
+    # merge runs of identical weights
+    i = 0
+    while i + 1 < len(res):
+        if res[i][1] == res[i + 1][1]:
+            res[i][0] += res[i + 1][0]
+            res.pop(i + 1)
+        else:
+            i += 1
+
+    return res
+
+
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
+    r"""
+    Tokenize a list of prompts and return its tokens with weights of each token.
+
+    No padding, starting or ending token is included.
+    """
+    tokens = []
+    weights = []
+    truncated = False
+    for text in prompt:
+        texts_and_weights = parse_prompt_attention(text)
+        text_token = []
+        text_weight = []
+        for word, weight in texts_and_weights:
+            # tokenize and discard the starting and the ending token
+            token = pipe.tokenizer(word).input_ids[1:-1]
+            text_token += token
+            # copy the weight by length of token
+            text_weight += [weight] * len(token)
+            # stop if the text is too long (longer than truncation limit)
+            if len(text_token) > max_length:
+                truncated = True
+                break
+        # truncate
+        if len(text_token) > max_length:
+            truncated = True
+            text_token = text_token[:max_length]
+            text_weight = text_weight[:max_length]
+        tokens.append(text_token)
+        weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+    return tokens, weights
+
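+# e.g. get_prompts_with_weights(pipe, ["a (red:1.2) car"], max_length=75) would return the token ids of
+# "a red car" (without BOS/EOS) and per-token weights [1.0, 1.2, 1.0]; exact ids depend on the tokenizer.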
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+    r"""
+    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+    """
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+    for i in range(len(tokens)):
+        tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
+        if no_boseos_middle:
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+        else:
+            w = []
+            if len(weights[i]) == 0:
+                w = [1.0] * weights_length
+            else:
+                for j in range(max_embeddings_multiples):
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
+            weights[i] = w[:]
+
+    return tokens, weights
+
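+# e.g. with chunk_length=77 and max_embeddings_multiples=3, max_length is (77 - 2) * 3 + 2 = 227:
+# each prompt becomes BOS + tokens + EOS + PAD up to 227 ids, and the weights are padded with 1.0 to match.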
+
+def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
+    if not is_sdxl_text_encoder2:
+        # text_encoder1: same as SD1/2
+        enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
+        hidden_states = enc_out["hidden_states"][11]
+        pool = None
+    else:
+        # text_encoder2
+        enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
+        hidden_states = enc_out["hidden_states"][-2]  # penuultimate layer
+        # pool = enc_out["text_embeds"]
+        pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
+    hidden_states = hidden_states.to(device)
+    if pool is not None:
+        pool = pool.to(device)
+    return hidden_states, pool
+
+
+def get_unweighted_text_embeddings(
+    pipe: StableDiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    clip_skip: int,
+    eos: int,
+    pad: int,
+    is_sdxl_text_encoder2: bool,
+    no_boseos_middle: Optional[bool] = True,
+):
+    """
+    When the length of tokens exceeds the capacity of the text encoder,
+    the input is split into chunks and each chunk is sent to the text encoder individually.
+    """
+    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+    text_pool = None
+    if max_embeddings_multiples > 1:
+        text_embeddings = []
+        for i in range(max_embeddings_multiples):
+            # extract the i-th chunk
+            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+            # cover the head and the tail by the starting and the ending tokens
+            text_input_chunk[:, 0] = text_input[0, 0]
+            if pad == eos:  # v1
+                text_input_chunk[:, -1] = text_input[0, -1]
+            else:  # v2
+                for j in range(len(text_input_chunk)):
+                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # a regular token is at the end
+                        text_input_chunk[j, -1] = eos
+                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
+                        text_input_chunk[j, 1] = eos
+
+            text_embedding, current_text_pool = get_hidden_states(
+                pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
+            )
+            if text_pool is None:
+                text_pool = current_text_pool
+
+            if no_boseos_middle:
+                if i == 0:
+                    # discard the ending token
+                    text_embedding = text_embedding[:, :-1]
+                elif i == max_embeddings_multiples - 1:
+                    # discard the starting token
+                    text_embedding = text_embedding[:, 1:]
+                else:
+                    # discard both starting and ending tokens
+                    text_embedding = text_embedding[:, 1:-1]
+
+            text_embeddings.append(text_embedding)
+        text_embeddings = torch.concat(text_embeddings, axis=1)
+    else:
+        text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
+    return text_embeddings, text_pool
+
+
+def get_weighted_text_embeddings(
+    pipe,  # : SdxlStableDiffusionLongPromptWeightingPipeline,
+    prompt: Union[str, List[str]],
+    uncond_prompt: Optional[Union[str, List[str]]] = None,
+    max_embeddings_multiples: Optional[int] = 3,
+    no_boseos_middle: Optional[bool] = False,
+    skip_parsing: Optional[bool] = False,
+    skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+    is_sdxl_text_encoder2=False,
+):
+    r"""
+    Prompts can be assigned with local weights using brackets. For example,
+    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+    Args:
+        pipe (`StableDiffusionPipeline`):
+            Pipe to provide access to the tokenizer and the text encoder.
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to guide the image generation.
+        uncond_prompt (`str` or `List[str]`):
+            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+            is provided, the embeddings of prompt and uncond_prompt are concatenated.
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+            The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            When the length of the text tokens is a multiple of the capacity of the text encoder, whether to
+            keep the starting and ending tokens of each chunk in the middle.
+        skip_parsing (`bool`, *optional*, defaults to `False`):
+            Skip the parsing of brackets.
+        skip_weighting (`bool`, *optional*, defaults to `False`):
+            Skip the weighting. When the parsing is skipped, it is forced True.
+    """
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    if not skip_parsing:
+        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+    else:
+        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
+        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+        if uncond_prompt is not None:
+            if isinstance(uncond_prompt, str):
+                uncond_prompt = [uncond_prompt]
+            uncond_tokens = [
+                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+            ]
+            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+    # round up the longest length of tokens to a multiple of (model_max_length - 2)
+    max_length = max([len(token) for token in prompt_tokens])
+    if uncond_prompt is not None:
+        max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+    max_embeddings_multiples = min(
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+    )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
+    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+    # pad the length of tokens and weights
+    bos = pipe.tokenizer.bos_token_id
+    eos = pipe.tokenizer.eos_token_id
+    pad = pipe.tokenizer.pad_token_id
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        pad,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=pipe.tokenizer.model_max_length,
+    )
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+    if uncond_prompt is not None:
+        uncond_tokens, uncond_weights = pad_tokens_and_weights(
+            uncond_tokens,
+            uncond_weights,
+            max_length,
+            bos,
+            eos,
+            pad,
+            no_boseos_middle=no_boseos_middle,
+            chunk_length=pipe.tokenizer.model_max_length,
+        )
+        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+    # get the embeddings
+    text_embeddings, text_pool = get_unweighted_text_embeddings(
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        clip_skip,
+        eos,
+        pad,
+        is_sdxl_text_encoder2,
+        no_boseos_middle=no_boseos_middle,
+    )
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
+
+    if uncond_prompt is not None:
+        uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            clip_skip,
+            eos,
+            pad,
+            is_sdxl_text_encoder2,
+            no_boseos_middle=no_boseos_middle,
+        )
+        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
+
+    # assign weights to the prompts and normalize in the sense of mean
+    # TODO: should we normalize by chunk or in a whole (current implementation)?
+    if (not skip_parsing) and (not skip_weighting):
+        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= prompt_weights.unsqueeze(-1)
+        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+        if uncond_prompt is not None:
+            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= uncond_weights.unsqueeze(-1)
+            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+    if uncond_prompt is not None:
+        return text_embeddings, text_pool, uncond_embeddings, uncond_pool
+    return text_embeddings, text_pool, None, None
+
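+# Illustrative call (`pipe` is an SdxlStableDiffusionLongPromptWeightingPipeline as defined below;
+# the returned shapes depend on the text encoder):
+#   te, te_pool, ue, ue_pool = get_weighted_text_embeddings(pipe, "a (very beautiful:1.2) masterpiece", uncond_prompt="")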
+
+def preprocess_image(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
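+# e.g. preprocess_image(PIL.Image.open("input.png")) returns, for an RGB image, a (1, 3, H, W) float tensor
+# scaled to [-1, 1], with H and W rounded down to multiples of 32 ("input.png" is a placeholder path).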
+
+def preprocess_mask(mask, scale_factor=8):
+    mask = mask.convert("L")
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_controlnet_image(
+    image: PIL.Image.Image,
+    width: int,
+    height: int,
+    batch_size: int,
+    num_images_per_prompt: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    do_classifier_free_guidance: bool = False,
+    guess_mode: bool = False,
+):
+    if not isinstance(image, torch.Tensor):
+        if isinstance(image, PIL.Image.Image):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            images = []
+
+            for image_ in image:
+                image_ = image_.convert("RGB")
+                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+                image_ = np.array(image_)
+                image_ = image_[None, :]
+                images.append(image_)
+
+            image = images
+
+            image = np.concatenate(image, axis=0)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = image.transpose(0, 3, 1, 2)
+            image = torch.from_numpy(image)
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0)
+
+    image_batch_size = image.shape[0]
+
+    if image_batch_size == 1:
+        repeat_by = batch_size
+    else:
+        # image batch size is the same as prompt batch size
+        repeat_by = num_images_per_prompt
+
+    image = image.repeat_interleave(repeat_by, dim=0)
+
+    image = image.to(device=device, dtype=dtype)
+
+    if do_classifier_free_guidance and not guess_mode:
+        image = torch.cat([image] * 2)
+
+    return image
+
+
+class SdxlStableDiffusionLongPromptWeightingPipeline:
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
+    weighting in prompt.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: List[CLIPTextModel],
+        tokenizer: List[CLIPTokenizer],
+        unet: UNet2DConditionModel,
+        scheduler: SchedulerMixin,
+        # clip_skip: int,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
+        clip_skip: int = 1,
+    ):
+        # clip skip is ignored currently
+        self.tokenizer = tokenizer[0]
+        self.text_encoder = text_encoder[0]
+        self.unet = unet
+        self.scheduler = scheduler
+        self.safety_checker = safety_checker
+        self.feature_extractor = feature_extractor
+        self.requires_safety_checker = requires_safety_checker
+        self.vae = vae
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.progress_bar = lambda x: tqdm(x, leave=False)
+
+        self.clip_skip = clip_skip
+        self.tokenizers = tokenizer
+        self.text_encoders = text_encoder
+
+    #     self.__init__additional__()
+
+    # def __init__additional__(self):
+    #     if not hasattr(self, "vae_scale_factor"):
+    #         setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+    def to(self, device=None, dtype=None):
+        if device is not None:
+            self.device = device
+            # self.vae.to(device=self.device)
+        if dtype is not None:
+            self.dtype = dtype
+
+        # do not move Text Encoders to device, because Text Encoder should be on CPU
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        max_embeddings_multiples,
+        is_sdxl_text_encoder2,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+            max_embeddings_multiples=max_embeddings_multiples,
+            clip_skip=self.clip_skip,
+            is_sdxl_text_encoder2=is_sdxl_text_encoder2,
+        )
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # duplicate embeddings for each image per prompt
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        if text_pool is not None:
+            text_pool = text_pool.repeat(1, num_images_per_prompt)
+            text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
+
+        if do_classifier_free_guidance:
+            bs_embed, seq_len, _ = uncond_embeddings.shape
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+            if uncond_pool is not None:
+                uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
+                uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
+
+            return text_embeddings, text_pool, uncond_embeddings, uncond_pool
+
+        return text_embeddings, text_pool, None, None
+
+    def check_inputs(self, prompt, height, width, strength, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
+            )
+
+    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+        if is_text2img:
+            return self.scheduler.timesteps.to(device), num_inference_steps
+        else:
+            # get the original timestep using init_timestep
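+            # (e.g. num_inference_steps=50, strength=0.8, offset=0 -> init_timestep=40, t_start=10,
+            #  so denoising runs over the last 40 scheduler timesteps)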
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(device)
+            return timesteps, num_inference_steps - t_start
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        with torch.no_grad():
+            latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
+
+            # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype)  # torch.float32
+            # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
+            # print("latents dtype:", latents.dtype, "x dtype:", x.dtype)  # torch.float32, torch.float16
+            # self.vae.to("cpu")
+            # self.vae.set_use_memory_efficient_attention_xformers(False)
+            # image = self.vae.decode(latents.to("cpu")).sample
+
+            image = self.vae.decode(latents.to(self.vae.dtype)).sample
+            image = (image / 2 + 0.5).clamp(0, 1)
+            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+            return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+        if image is None:
+            shape = (
+                batch_size,
+                self.unet.in_channels,
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            )
+
+            if latents is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+                else:
+                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            else:
+                if latents.shape != shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+                latents = latents.to(device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+            return latents, None, None
+        else:
+            init_latent_dist = self.vae.encode(image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            init_latents_orig = init_latents
+            shape = init_latents.shape
+
+            # add noise to latents using the timesteps
+            if device.type == "mps":
+                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep)
+            return latents, init_latents_orig, noise
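+
+    # Note on the two branches above: for text-to-image (image is None) the returned tuple is
+    # (random latents scaled by scheduler.init_noise_sigma, None, None); for image-to-image the input
+    # image is VAE-encoded, scaled by VAE_SCALE_FACTOR and noised to `timestep`, while init_latents_orig
+    # and noise are kept so the masked region can be re-composited during inpainting.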
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        strength: float = 0.8,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        controlnet=None,
+        controlnet_image=None,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            controlnet (`diffusers.ControlNetModel`, *optional*):
+                A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
+            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+                `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
+                inference.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            `None` if cancelled by `is_cancelled_callback`,
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        if controlnet is not None and controlnet_image is None:
+            raise ValueError("controlnet_image must be provided if controlnet is not None.")
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, strength, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        # To simplify the implementation, swap in each tokenizer / text encoder and call _encode_prompt once per encoder
+        text_embeddings_list = []
+        text_pool = None
+        uncond_embeddings_list = []
+        uncond_pool = None
+        for i in range(len(self.tokenizers)):
+            self.tokenizer = self.tokenizers[i]
+            self.text_encoder = self.text_encoders[i]
+
+            text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
+                prompt,
+                device,
+                num_images_per_prompt,
+                do_classifier_free_guidance,
+                negative_prompt,
+                max_embeddings_multiples,
+                is_sdxl_text_encoder2=i == 1,
+            )
+            text_embeddings_list.append(text_embeddings)
+            uncond_embeddings_list.append(uncond_embeddings)
+
+            if tp1 is not None:
+                text_pool = tp1
+            if up1 is not None:
+                uncond_pool = up1
+
+        dtype = self.unet.dtype
+
+        # 4. Preprocess image and mask
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess_image(image)
+        if image is not None:
+            image = image.to(device=self.device, dtype=dtype)
+        if isinstance(mask_image, PIL.Image.Image):
+            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        if mask_image is not None:
+            mask = mask_image.to(device=self.device, dtype=dtype)
+            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        else:
+            mask = None
+
+        # ControlNet is not working yet in SDXL, but keep the code here for future use
+        if controlnet_image is not None:
+            controlnet_image = prepare_controlnet_image(
+                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
+            )
+
+        # 5. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents, init_latents_orig, noise = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # create size embs and concat embeddings for SDXL
+        orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
+        crop_size = torch.zeros_like(orig_size)
+        target_size = orig_size
+        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
+
+        # make conditionings
+        if do_classifier_free_guidance:
+            text_embeddings = torch.cat(text_embeddings_list, dim=2)
+            uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
+            text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
+
+            cond_vector = torch.cat([text_pool, embs], dim=1)
+            uncond_vector = torch.cat([uncond_pool, embs], dim=1)
+            vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
+        else:
+            text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
+            vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
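+
+        # Shape note: the hidden states of the two text encoders (768- and 1280-dim) are concatenated
+        # along the channel dim to form the 2048-dim cross-attention context, and the pooled embedding
+        # (1280) plus the six size values embedded at 256 dims each (1536) form the 2816-dim vector
+        # conditioning expected by the SDXL U-Net.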
+
+        # 8. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            unet_additional_args = {}
+            if controlnet is not None:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    controlnet_cond=controlnet_image,
+                    conditioning_scale=1.0,
+                    guess_mode=False,
+                    return_dict=False,
+                )
+                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
+                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
+            noise_pred = noise_pred.to(dtype)  # U-Net changes dtype in LoRA training
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None
+
+        return latents
+
+    def latents_to_image(self, latents):
+        # 9. Post-processing
+        image = self.decode_latents(latents.to(self.vae.dtype))
+        image = self.numpy_to_pil(image)
+        return image
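+
+    # Illustrative usage (hypothetical variable names): since __call__ returns latents, images are
+    # produced in two steps, e.g.
+    #     latents = pipe("a photo of a cat", num_inference_steps=30)
+    #     images = pipe.latents_to_image(latents)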
+
+    # copy from pil_utils.py
+    def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
+
+    def text2img(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for text-to-image generation.
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            latents=latents,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+    def img2img(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for image-to-image generation.
+        Args:
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference. This parameter will be modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+    def inpaint(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image],
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: int = 1,
+    ):
+        r"""
+        Function for inpaint.
+        Args:
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process. This is the image whose masked region will be inpainted.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+                is 1, the denoising process will be run on the masked area for the full number of iterations specified
+                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image,
+            mask_image=mask_image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
diff --git a/external/llite/library/sdxl_model_util.py b/external/llite/library/sdxl_model_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..55eb910dea436212a43a6e334b350957292d851b
--- /dev/null
+++ b/external/llite/library/sdxl_model_util.py
@@ -0,0 +1,578 @@
+import torch
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
+from safetensors.torch import load_file, save_file
+from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
+from typing import List
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
+from external.llite.library import model_util
+from external.llite.library import sdxl_original_unet
+
+
+VAE_SCALE_FACTOR = 0.13025
+MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
+
+# Reference model used to load the Diffusers configuration
+DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
+
+DIFFUSERS_SDXL_UNET_CONFIG = {
+    "act_fn": "silu",
+    "addition_embed_type": "text_time",
+    "addition_embed_type_num_heads": 64,
+    "addition_time_embed_dim": 256,
+    "attention_head_dim": [5, 10, 20],
+    "block_out_channels": [320, 640, 1280],
+    "center_input_sample": False,
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
+    "cross_attention_dim": 2048,
+    "cross_attention_norm": None,
+    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
+    "downsample_padding": 1,
+    "dual_cross_attention": False,
+    "encoder_hid_dim": None,
+    "encoder_hid_dim_type": None,
+    "flip_sin_to_cos": True,
+    "freq_shift": 0,
+    "in_channels": 4,
+    "layers_per_block": 2,
+    "mid_block_only_cross_attention": None,
+    "mid_block_scale_factor": 1,
+    "mid_block_type": "UNetMidBlock2DCrossAttn",
+    "norm_eps": 1e-05,
+    "norm_num_groups": 32,
+    "num_attention_heads": None,
+    "num_class_embeds": None,
+    "only_cross_attention": False,
+    "out_channels": 4,
+    "projection_class_embeddings_input_dim": 2816,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
+    "resnet_time_scale_shift": "default",
+    "sample_size": 128,
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "transformer_layers_per_block": [1, 2, 10],
+    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
+    "upcast_attention": False,
+    "use_linear_projection": True,
+}
+
+
+def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
+    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
+
+    # Basically the same as the SD2 conversion. logit_scale is additionally returned because it is
+    # needed later, when the checkpoint is saved.
+    def convert_key(key):
+        # common conversion
+        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
+        key = key.replace(SDXL_KEY_PREFIX, "text_model.")
+
+        if "resblocks" in key:
+            # resblocks conversion
+            key = key.replace(".resblocks.", ".layers.")
+            if ".ln_" in key:
+                key = key.replace(".ln_", ".layer_norm")
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif ".attn.out_proj" in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif ".attn.in_proj" in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif ".positional_embedding" in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif ".text_projection" in key:
+            key = key.replace("text_model.text_projection", "text_projection.weight")
+        elif ".logit_scale" in key:
+            key = None  # handled separately below
+        elif ".token_embedding" in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif ".ln_final" in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
+        elif ".embeddings.position_ids" in key:
+            key = None  # remove this key: make position_ids by ourselves
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert the attention weights
+    for key in keys:
+        if ".resblocks" in key and ".attn.in_proj_" in key:
+            # split the fused in_proj into q, k, v
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    # position_ids does not exist in the original SD checkpoint, so add it here
+    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+    new_sd["text_model.embeddings.position_ids"] = position_ids
+
+    # logit_scale is not part of the Diffusers model, but return it separately so it can be restored when saving
+    logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
+
+    # temporary workaround for text_projection.weight.weight for Playground-v2
+    if "text_projection.weight.weight" in new_sd:
+        print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
+        new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
+        del new_sd["text_projection.weight.weight"]
+
+    return new_sd, logit_scale
+
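+# For example, the open clip key "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight"
+# becomes "text_model.encoder.layers.0.layer_norm1.weight", and each fused attn.in_proj_weight/bias is
+# split into separate q_proj / k_proj / v_proj tensors for CLIPTextModelWithProjection.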
+
+# load state_dict without allocating new tensors
+def _load_state_dict_on_device(model, state_dict, device, dtype=None):
+    # dtype will use fp32 as default
+    missing_keys = list(model.state_dict().keys() - state_dict.keys())
+    unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
+
+    # similar to model.load_state_dict()
+    if not missing_keys and not unexpected_keys:
+        for k in list(state_dict.keys()):
+            set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
+        return "<All keys matched successfully>"
+
+    # error_msgs
+    error_msgs: List[str] = []
+    if missing_keys:
+        error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
+    if unexpected_keys:
+        error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
+
+    raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
+
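+# Minimal usage sketch (assuming `unet_sd` is a state dict already filtered to U-Net keys): models built
+# under accelerate's init_empty_weights() hold meta tensors, so this helper materializes each tensor
+# directly on the target device instead of going through model.load_state_dict():
+#
+#     with init_empty_weights():
+#         unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+#     _load_state_dict_on_device(unet, unet_sd, device="cpu", dtype=torch.float16)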
+
+def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
+    # model_version is reserved for future use
+    # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
+
+    # Load the state dict
+    if model_util.is_safetensors(ckpt_path):
+        checkpoint = None
+        try:
+            state_dict = load_file(ckpt_path, device=map_location)
+        except Exception:
+            state_dict = load_file(ckpt_path)  # retry without a device to avoid an invalid-device error
+        epoch = None
+        global_step = None
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=map_location)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+            epoch = checkpoint.get("epoch", 0)
+            global_step = checkpoint.get("global_step", 0)
+        else:
+            state_dict = checkpoint
+            epoch = 0
+            global_step = 0
+        checkpoint = None
+
+    # U-Net
+    print("building U-Net")
+    with init_empty_weights():
+        unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+
+    print("loading U-Net from checkpoint")
+    unet_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("model.diffusion_model."):
+            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
+    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
+    print("U-Net: ", info)
+
+    # Text Encoders
+    print("building text encoders")
+
+    # Text Encoder 1 is the same as Stability AI's SDXL
+    text_model1_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        max_position_embeddings=77,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=768,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    with init_empty_weights():
+        text_model1 = CLIPTextModel._from_config(text_model1_cfg)
+
+    # Text Encoder 2 differs from Stability AI's SDXL: SDXL uses open clip, whereas we use the HuggingFace model.
+    # Note: the HuggingFace tokenizer also differs from SDXL's, so open clip's tokenizer must be used.
+    text_model2_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=1280,
+        intermediate_size=5120,
+        num_hidden_layers=32,
+        num_attention_heads=20,
+        max_position_embeddings=77,
+        hidden_act="gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=1280,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    with init_empty_weights():
+        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
+
+    print("loading text encoders from checkpoint")
+    te1_sd = {}
+    te2_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("conditioner.embedders.0.transformer."):
+            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
+        elif k.startswith("conditioner.embedders.1.model."):
+            te2_sd[k] = state_dict.pop(k)
+
+    # some checkpoints are missing position_ids, so add them
+    if "text_model.embeddings.position_ids" not in te1_sd:
+        te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
+
+    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
+    print("text encoder 1:", info1)
+
+    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
+    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
+    print("text encoder 2:", info2)
+
+    # prepare vae
+    print("building VAE")
+    vae_config = model_util.create_vae_diffusers_config()
+    with init_empty_weights():
+        vae = AutoencoderKL(**vae_config)
+
+    print("loading VAE from checkpoint")
+    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
+    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
+    print("VAE:", info)
+
+    ckpt_info = (epoch, global_step) if epoch is not None else None
+    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
+
+
+def make_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commentout for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map
+
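+# For example, the map above contains (stable-diffusion prefix, Diffusers prefix) pairs such as
+# ("input_blocks.1.0.in_layers.0.", "down_blocks.0.resnets.0.norm1.") and
+# ("middle_block.1.", "mid_block.attentions.0.").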
+
+def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_map = {hf: sd for sd, hf in unet_conversion_map}
+    return convert_unet_state_dict(du_sd, conversion_map)
+
+
+def convert_unet_state_dict(src_sd, conversion_map):
+    converted_sd = {}
+    for src_key, value in src_sd.items():
+        # checking every map entry for every key would be slow, so strip components from the right of
+        # the key while searching for a matching prefix
+        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
+        while len(src_key_fragments) > 0:
+            src_key_prefix = ".".join(src_key_fragments) + "."
+            if src_key_prefix in conversion_map:
+                converted_prefix = conversion_map[src_key_prefix]
+                converted_key = converted_prefix + src_key[len(src_key_prefix) :]
+                converted_sd[converted_key] = value
+                break
+            src_key_fragments.pop(-1)
+        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
+
+    return converted_sd
+
+
+def convert_sdxl_unet_state_dict_to_diffusers(sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
+    return convert_unet_state_dict(sd, conversion_dict)
+
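+# Illustrative round trip: convert_sdxl_unet_state_dict_to_diffusers maps e.g.
+# "input_blocks.1.0.in_layers.0.weight" -> "down_blocks.0.resnets.0.norm1.weight"; the prefix is swapped
+# while the trailing "weight"/"bias" suffix is preserved, and convert_diffusers_unet_state_dict_to_sdxl
+# applies the inverse mapping.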
+
+def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
+    def convert_key(key):
+        # remove position_ids
+        if ".position_ids" in key:
+            return None
+
+        # common
+        key = key.replace("text_model.encoder.", "transformer.")
+        key = key.replace("text_model.", "")
+        if "layers" in key:
+            # resblocks conversion
+            key = key.replace(".layers.", ".resblocks.")
+            if ".layer_norm" in key:
+                key = key.replace(".layer_norm", ".ln_")
+            elif ".mlp." in key:
+                key = key.replace(".fc1.", ".c_fc.")
+                key = key.replace(".fc2.", ".c_proj.")
+            elif ".self_attn.out_proj" in key:
+                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+            elif ".self_attn." in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in Diffusers model: {key}")
+        elif ".position_embedding" in key:
+            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+        elif ".token_embedding" in key:
+            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+        elif "text_projection" in key:  # no dot in key
+            key = key.replace("text_projection.weight", "text_projection")
+        elif "final_layer_norm" in key:
+            key = key.replace("final_layer_norm", "ln_final")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert the attention weights
+    for key in keys:
+        if "layers" in key and "q_proj" in key:
+            # concatenate q, k, v back into a single fused in_proj tensor
+            key_q = key
+            key_k = key.replace("q_proj", "k_proj")
+            key_v = key.replace("q_proj", "v_proj")
+
+            value_q = checkpoint[key_q]
+            value_k = checkpoint[key_k]
+            value_v = checkpoint[key_v]
+            value = torch.cat([value_q, value_k, value_v])
+
+            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+            new_sd[new_key] = value
+
+    if logit_scale is not None:
+        new_sd["logit_scale"] = logit_scale
+
+    return new_sd
+
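+# This is the inverse of convert_sdxl_text_encoder_2_checkpoint: for example, the three HuggingFace keys
+# "text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight" are concatenated back into the single
+# open clip key "transformer.resblocks.0.attn.in_proj_weight".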
+
+def save_stable_diffusion_checkpoint(
+    output_file,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    epochs,
+    steps,
+    ckpt_info,
+    vae,
+    logit_scale,
+    metadata,
+    save_dtype=None,
+):
+    state_dict = {}
+
+    def update_sd(prefix, sd):
+        for k, v in sd.items():
+            key = prefix + k
+            if save_dtype is not None:
+                v = v.detach().clone().to("cpu").to(save_dtype)
+            state_dict[key] = v
+
+    # Convert the UNet model
+    update_sd("model.diffusion_model.", unet.state_dict())
+
+    # Convert the text encoders
+    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
+
+    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
+    update_sd("conditioner.embedders.1.model.", text_enc2_dict)
+
+    # Convert the VAE
+    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
+    update_sd("first_stage_model.", vae_dict)
+
+    # Put together new checkpoint
+    key_count = len(state_dict.keys())
+    new_ckpt = {"state_dict": state_dict}
+
+    # epoch and global_step are sometimes not int
+    if ckpt_info is not None:
+        epochs += ckpt_info[0]
+        steps += ckpt_info[1]
+
+    new_ckpt["epoch"] = epochs
+    new_ckpt["global_step"] = steps
+
+    if model_util.is_safetensors(output_file):
+        save_file(state_dict, output_file, metadata)
+    else:
+        torch.save(new_ckpt, output_file)
+
+    return key_count
+
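+# A minimal call sketch (hypothetical file name and counters): saving an fp16 single-file checkpoint
+# could look like
+#
+#     save_stable_diffusion_checkpoint(
+#         "sdxl_finetuned.safetensors", text_encoder1, text_encoder2, unet,
+#         epochs=1, steps=1000, ckpt_info=None, vae=vae, logit_scale=logit_scale,
+#         metadata=None, save_dtype=torch.float16,
+#     )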
+
+def save_diffusers_checkpoint(
+    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
+):
+    from diffusers import StableDiffusionXLPipeline
+
+    # convert U-Net
+    unet_sd = unet.state_dict()
+    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
+
+    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
+    if save_dtype is not None:
+        diffusers_unet.to(save_dtype)
+    diffusers_unet.load_state_dict(du_unet_sd)
+
+    # create pipeline to save
+    if pretrained_model_name_or_path is None:
+        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
+
+    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+    # prevent local path from being saved
+    def remove_name_or_path(model):
+        if hasattr(model, "config"):
+            model.config._name_or_path = None
+
+    remove_name_or_path(diffusers_unet)
+    remove_name_or_path(text_encoder1)
+    remove_name_or_path(text_encoder2)
+    remove_name_or_path(scheduler)
+    remove_name_or_path(tokenizer1)
+    remove_name_or_path(tokenizer2)
+    remove_name_or_path(vae)
+
+    pipeline = StableDiffusionXLPipeline(
+        unet=diffusers_unet,
+        text_encoder=text_encoder1,
+        text_encoder_2=text_encoder2,
+        vae=vae,
+        scheduler=scheduler,
+        tokenizer=tokenizer1,
+        tokenizer_2=tokenizer2,
+    )
+    if save_dtype is not None:
+        pipeline.to(None, save_dtype)
+    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
diff --git a/external/llite/library/sdxl_original_unet.py b/external/llite/library/sdxl_original_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..babda8ec58a18bdf99ad38bc83f55341473caa21
--- /dev/null
+++ b/external/llite/library/sdxl_original_unet.py
@@ -0,0 +1,1281 @@
+# U-Net of sd_xl_base, based on the Diffusers code
+# the state dict format is aligned with SDXL
+
+"""
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2816
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+        context_dim: 2048
+        spatial_transformer_attn_type: softmax-xformers
+        legacy: False
+"""
+
+import math
+from types import SimpleNamespace
+from typing import Any, Optional
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+from einops import rearrange
+
+
+IN_CHANNELS: int = 4
+OUT_CHANNELS: int = 4
+ADM_IN_CHANNELS: int = 2816
+CONTEXT_DIM: int = 2048
+MODEL_CHANNELS: int = 320
+TIME_EMBED_DIM = 320 * 4
+
+USE_REENTRANT = True
+
+# region memory efficient attention
+
+# CrossAttention using FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+
+class FlashAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """Algorithm 2 in the paper"""
+
+        device = q.device
+        dtype = q.dtype
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
+
+        scale = q.shape[-1] ** -0.5
+
+        if not exists(mask):
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, "b n -> b 1 1 n")
+            mask = mask.split(q_bucket_size, dim=-1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim=-2),
+            all_row_maxes.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if exists(row_mask):
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if exists(row_mask):
+                    exp_weights.masked_fill_(~row_mask, 0.0)
+
+                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
+
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
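+                # online softmax update: rescale the previously accumulated output and row sums by
+                # exp(old_max - new_max) before adding the contribution of this column block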
+                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+        return o
+
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """Algorithm 4 in the paper"""
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, l, m = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            do.split(q_bucket_size, dim=-2),
+            mask,
+            l.split(q_bucket_size, dim=-2),
+            m.split(q_bucket_size, dim=-2),
+            dq.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+                dk.split(k_bucket_size, dim=-2),
+                dv.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
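+                # recompute the softmax probabilities from the saved row maxes (m) and row sums (l)
+                # instead of storing the full attention matrix in the forward pass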
+                exp_attn_weights = torch.exp(attn_weights - mc)
+
+                if exists(row_mask):
+                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                p = exp_attn_weights / lc
+
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
+                D = (doc * oc).sum(dim=-1, keepdims=True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None
+
+
+# endregion
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+    return next(parameter.parameters()).dtype
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+    return next(parameter.parameters()).device
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+    embeddings. :return: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
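+    # frequencies decay as 1 / max_period^(i / (half_dim - downscale_freq_shift)); each timestep t is
+    # embedded as [cos(t * freq_i), ...] followed by [sin(t * freq_i), ...] (cos first, see below)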
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat cosine and sine embeddings (order flipped relative to the original Diffusers version, since flip_sin_to_cos is always True here)
+    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+# Deep Shrink: this helper is intentionally not shared with other modules, to keep dependencies minimal.
+def resize_like(x, target, mode="bicubic", align_corners=False):
+    org_dtype = x.dtype
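+    # bicubic/nearest interpolation may not support bfloat16 on some PyTorch builds, so upcast temporarily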
+    if org_dtype == torch.bfloat16:
+        x = x.to(torch.float32)
+
+    if x.shape[-2:] != target.shape[-2:]:
+        if mode == "nearest":
+            x = F.interpolate(x, size=target.shape[-2:], mode=mode)
+        else:
+            x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
+
+    if org_dtype == torch.bfloat16:
+        x = x.to(org_dtype)
+    return x
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        if self.weight.dtype != torch.float32:
+            return super().forward(x)
+        return super().forward(x.float()).type(x.dtype)
+
+
+class ResnetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.in_layers = nn.Sequential(
+            GroupNorm32(32, in_channels),
+            nn.SiLU(),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
+        )
+
+        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))
+
+        self.out_layers = nn.Sequential(
+            GroupNorm32(32, out_channels),
+            nn.SiLU(),
+            nn.Identity(),  # to make state_dict compatible with original model
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
+        )
+
+        if in_channels != out_channels:
+            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        else:
+            self.skip_connection = nn.Identity()
+
+        self.gradient_checkpointing = False
+
+    def forward_body(self, x, emb):
+        h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        h = h + emb_out[:, :, None, None]
+        h = self.out_layers(h)
+        x = self.skip_connection(x)
+        return x + h
+
+    def forward(self, x, emb):
+        if self.training and self.gradient_checkpointing:
+            # print("ResnetBlock2D: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
+        else:
+            x = self.forward_body(x, emb)
+
+        return x
+
+
+class Downsample2D(nn.Module):
+    def __init__(self, channels, out_channels):
+        super().__init__()
+
+        self.channels = channels
+        self.out_channels = out_channels
+
+        self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
+
+        self.gradient_checkpointing = False
+
+    def forward_body(self, hidden_states):
+        assert hidden_states.shape[1] == self.channels
+        hidden_states = self.op(hidden_states)
+
+        return hidden_states
+
+    def forward(self, hidden_states):
+        if self.training and self.gradient_checkpointing:
+            # print("Downsample2D: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
+            )
+        else:
+            hidden_states = self.forward_body(hidden_states)
+
+        return hidden_states
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        inner_dim = dim_head * heads
+        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        # no dropout here
+
+        self.use_memory_efficient_attention_xformers = False
+        self.use_memory_efficient_attention_mem_eff = False
+        self.use_sdpa = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_memory_efficient_attention_mem_eff = mem_eff
+
+    def set_use_sdpa(self, sdpa):
+        self.use_sdpa = sdpa
+
+    def reshape_heads_to_batch_dim(self, tensor):
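+        # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)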
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
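+        # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)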
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def forward(self, hidden_states, context=None, mask=None):
+        if self.use_memory_efficient_attention_xformers:
+            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
+        if self.use_memory_efficient_attention_mem_eff:
+            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
+        if self.use_sdpa:
+            return self.forward_sdpa(hidden_states, context, mask)
+
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
+
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
+
+        hidden_states = self._attention(query, key, value)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # hidden_states = self.to_out[1](hidden_states)     # no dropout
+        return hidden_states
+
+    def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    # TODO support Hypernetworks
+    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
+        import xformers.ops
+
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best kernel automatically
+        del q, k, v
+
+        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+
+    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
+        flash_func = FlashAttentionFunction
+
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        h = self.heads
+        q = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        out = self.to_out[0](out)
+        return out
+
+    def forward_sdpa(self, x, context=None, mask=None):
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+        out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+
+
+# feedforward
+class GEGLU(nn.Module):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * self.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+    ):
+        super().__init__()
+        inner_dim = int(dim * 4)  # mult is always 4
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Identity())  # nn.Dropout(0)) # dummy for dropout with 0
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim))
+
+    def forward(self, hidden_states):
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
+    ):
+        super().__init__()
+
+        self.gradient_checkpointing = False
+
+        # 1. Self-Attn
+        self.attn1 = CrossAttention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            upcast_attention=upcast_attention,
+        )
+        self.ff = FeedForward(dim)
+
+        # 2. Cross-Attn
+        self.attn2 = CrossAttention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            upcast_attention=upcast_attention,
+        )
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim)
+
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
+        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
+        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa: bool):
+        self.attn1.set_use_sdpa(sdpa)
+        self.attn2.set_use_sdpa(sdpa)
+
+    def forward_body(self, hidden_states, context=None, timestep=None):
+        # 1. Self-Attention
+        norm_hidden_states = self.norm1(hidden_states)
+
+        hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+        # 2. Cross-Attention
+        norm_hidden_states = self.norm2(hidden_states)
+        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+
+        # 3. Feed-forward
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+        return hidden_states
+
+    def forward(self, hidden_states, context=None, timestep=None):
+        if self.training and self.gradient_checkpointing:
+            # print("BasicTransformerBlock: checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            output = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
+            )
+        else:
+            output = self.forward_body(hidden_states, context, timestep)
+
+        return output
+
+
+class Transformer2DModel(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        cross_attention_dim: Optional[int] = None,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+        num_transformer_layers: int = 1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        self.use_linear_projection = use_linear_projection
+
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)
+
+        if use_linear_projection:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+        blocks = []
+        for _ in range(num_transformer_layers):
+            blocks.append(
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    upcast_attention=upcast_attention,
+                )
+            )
+
+        self.transformer_blocks = nn.ModuleList(blocks)
+
+        if use_linear_projection:
+            self.proj_out = nn.Linear(in_channels, inner_dim)
+        else:
+            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_sdpa(sdpa)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
+        # 1. Input
+        batch, _, height, weight = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
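+        # with linear projection the feature map is flattened to (batch, height * width, channels) before
+        # proj_in; with the conv projection, proj_in runs on the 4D map first and the flatten happens after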
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+
+        # 3. Output
+        if not self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        return output
+
+
+class Upsample2D(nn.Module):
+    def __init__(self, channels, out_channels):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels
+        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+        self.gradient_checkpointing = False
+
+    def forward_body(self, hidden_states, output_size=None):
+        assert hidden_states.shape[1] == self.channels
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+        # https://github.com/pytorch/pytorch/issues/86679
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.float32)
+
+        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+        if hidden_states.shape[0] >= 64:
+            hidden_states = hidden_states.contiguous()
+
+        # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+        # If the input is bfloat16, we cast back to bfloat16
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(dtype)
+
+        hidden_states = self.conv(hidden_states)
+
+        return hidden_states
+
+    def forward(self, hidden_states, output_size=None):
+        if self.training and self.gradient_checkpointing:
+            # print("Upsample2D: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
+            )
+        else:
+            hidden_states = self.forward_body(hidden_states, output_size)
+
+        return hidden_states
+
+
+class SdxlUNet2DConditionModel(nn.Module):
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.in_channels = IN_CHANNELS
+        self.out_channels = OUT_CHANNELS
+        self.model_channels = MODEL_CHANNELS
+        self.time_embed_dim = TIME_EMBED_DIM
+        self.adm_in_channels = ADM_IN_CHANNELS
+
+        self.gradient_checkpointing = False
+        # self.sample_size = sample_size
+
+        # time embedding
+        self.time_embed = nn.Sequential(
+            nn.Linear(self.model_channels, self.time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.time_embed_dim, self.time_embed_dim),
+        )
+
+        # label embedding
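+        # nested Sequential keeps the state dict keys ("label_emb.0.0.*") aligned with the original SDXL U-Net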
+        self.label_emb = nn.Sequential(
+            nn.Sequential(
+                nn.Linear(self.adm_in_channels, self.time_embed_dim),
+                nn.SiLU(),
+                nn.Linear(self.time_embed_dim, self.time_embed_dim),
+            )
+        )
+
+        # input
+        self.input_blocks = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
+                )
+            ]
+        )
+
+        # level 0
+        for i in range(2):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=1 * self.model_channels,
+                    out_channels=1 * self.model_channels,
+                ),
+            ]
+            self.input_blocks.append(nn.ModuleList(layers))
+
+        self.input_blocks.append(
+            nn.Sequential(
+                Downsample2D(
+                    channels=1 * self.model_channels,
+                    out_channels=1 * self.model_channels,
+                ),
+            )
+        )
+
+        # level 1
+        for i in range(2):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=(1 if i == 0 else 2) * self.model_channels,
+                    out_channels=2 * self.model_channels,
+                ),
+                Transformer2DModel(
+                    num_attention_heads=2 * self.model_channels // 64,
+                    attention_head_dim=64,
+                    in_channels=2 * self.model_channels,
+                    num_transformer_layers=2,
+                    use_linear_projection=True,
+                    cross_attention_dim=2048,
+                ),
+            ]
+            self.input_blocks.append(nn.ModuleList(layers))
+
+        self.input_blocks.append(
+            nn.Sequential(
+                Downsample2D(
+                    channels=2 * self.model_channels,
+                    out_channels=2 * self.model_channels,
+                ),
+            )
+        )
+
+        # level 2
+        for i in range(2):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=(2 if i == 0 else 4) * self.model_channels,
+                    out_channels=4 * self.model_channels,
+                ),
+                Transformer2DModel(
+                    num_attention_heads=4 * self.model_channels // 64,
+                    attention_head_dim=64,
+                    in_channels=4 * self.model_channels,
+                    num_transformer_layers=10,
+                    use_linear_projection=True,
+                    cross_attention_dim=2048,
+                ),
+            ]
+            self.input_blocks.append(nn.ModuleList(layers))
+
+        # mid
+        self.middle_block = nn.ModuleList(
+            [
+                ResnetBlock2D(
+                    in_channels=4 * self.model_channels,
+                    out_channels=4 * self.model_channels,
+                ),
+                Transformer2DModel(
+                    num_attention_heads=4 * self.model_channels // 64,
+                    attention_head_dim=64,
+                    in_channels=4 * self.model_channels,
+                    num_transformer_layers=10,
+                    use_linear_projection=True,
+                    cross_attention_dim=2048,
+                ),
+                ResnetBlock2D(
+                    in_channels=4 * self.model_channels,
+                    out_channels=4 * self.model_channels,
+                ),
+            ]
+        )
+
+        # output
+        self.output_blocks = nn.ModuleList([])
+
+        # level 2
+        for i in range(3):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
+                    out_channels=4 * self.model_channels,
+                ),
+                Transformer2DModel(
+                    num_attention_heads=4 * self.model_channels // 64,
+                    attention_head_dim=64,
+                    in_channels=4 * self.model_channels,
+                    num_transformer_layers=10,
+                    use_linear_projection=True,
+                    cross_attention_dim=2048,
+                ),
+            ]
+            if i == 2:
+                layers.append(
+                    Upsample2D(
+                        channels=4 * self.model_channels,
+                        out_channels=4 * self.model_channels,
+                    )
+                )
+
+            self.output_blocks.append(nn.ModuleList(layers))
+
+        # level 1
+        for i in range(3):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
+                    out_channels=2 * self.model_channels,
+                ),
+                Transformer2DModel(
+                    num_attention_heads=2 * self.model_channels // 64,
+                    attention_head_dim=64,
+                    in_channels=2 * self.model_channels,
+                    num_transformer_layers=2,
+                    use_linear_projection=True,
+                    cross_attention_dim=2048,
+                ),
+            ]
+            if i == 2:
+                layers.append(
+                    Upsample2D(
+                        channels=2 * self.model_channels,
+                        out_channels=2 * self.model_channels,
+                    )
+                )
+
+            self.output_blocks.append(nn.ModuleList(layers))
+
+        # level 0
+        for i in range(3):
+            layers = [
+                ResnetBlock2D(
+                    in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
+                    out_channels=1 * self.model_channels,
+                ),
+            ]
+
+            self.output_blocks.append(nn.ModuleList(layers))
+
+        # output
+        self.out = nn.ModuleList(
+            [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
+        )
+
+    # region diffusers compatibility
+    def prepare_config(self):
+        self.config = SimpleNamespace()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        return get_parameter_dtype(self)
+
+    @property
+    def device(self) -> torch.device:
+        # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
+        return get_parameter_device(self)
+
+    def set_attention_slice(self, slice_size):
+        raise NotImplementedError("Attention slicing is not supported for this model.")
+
+    def is_gradient_checkpointing(self) -> bool:
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        self.gradient_checkpointing = True
+        self.set_gradient_checkpointing(value=True)
+
+    def disable_gradient_checkpointing(self):
+        self.gradient_checkpointing = False
+        self.set_gradient_checkpointing(value=False)
+
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
+        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
+        for block in blocks:
+            for module in block:
+                if hasattr(module, "set_use_memory_efficient_attention"):
+                    # print(module.__class__.__name__)
+                    module.set_use_memory_efficient_attention(xformers, mem_eff)
+
+    def set_use_sdpa(self, sdpa: bool) -> None:
+        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
+        for block in blocks:
+            for module in block:
+                if hasattr(module, "set_use_sdpa"):
+                    module.set_use_sdpa(sdpa)
+
+    def set_gradient_checkpointing(self, value=False):
+        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
+        for block in blocks:
+            for module in block.modules():
+                if hasattr(module, "gradient_checkpointing"):
+                    # print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
+                    module.gradient_checkpointing = value
+
+    # endregion
+
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+        # broadcast timesteps to batch dimension
+        timesteps = timesteps.expand(x.shape[0])
+
+        hs = []
+        t_emb = get_timestep_embedding(timesteps, self.model_channels)  # , repeat_only=False)
+        t_emb = t_emb.to(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
+        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
+        # assert x.dtype == self.dtype
+        emb = emb + self.label_emb(y)
+
+        def call_module(module, h, emb, context):
+            x = h
+            for layer in module:
+                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
+                if isinstance(layer, ResnetBlock2D):
+                    x = layer(x, emb)
+                elif isinstance(layer, Transformer2DModel):
+                    x = layer(x, context)
+                else:
+                    x = layer(x)
+            return x
+
+        # h = x.type(self.dtype)
+        h = x
+
+        for module in self.input_blocks:
+            h = call_module(module, h, emb, context)
+            hs.append(h)
+
+        h = call_module(self.middle_block, h, emb, context)
+
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = call_module(module, h, emb, context)
+
+        h = h.type(x.dtype)
+        h = call_module(self.out, h, emb, context)
+
+        return h
+
+
+class InferSdxlUNet2DConditionModel:
+    def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
+        self.delegate = original_unet
+
+        # override original model's forward method: because forward is not called by `__call__`
+        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
+        self.delegate.forward = self.forward
+
+        # Deep Shrink
+        self.ds_depth_1 = None
+        self.ds_depth_2 = None
+        self.ds_timesteps_1 = None
+        self.ds_timesteps_2 = None
+        self.ds_ratio = None
+
+    # call original model's methods
+    def __getattr__(self, name):
+        return getattr(self.delegate, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.delegate(*args, **kwargs)
+
+    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
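+        # Deep Shrink: downsample the hidden states by ds_ratio at input block ds_depth_1 while
+        # timestep >= ds_timesteps_1 (and optionally at ds_depth_2 for ds_timesteps_2 <= t < ds_timesteps_1),
+        # then resize them back to the skip-connection resolution on the decoder side (see forward below)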
+        if ds_depth_1 is None:
+            print("Deep Shrink is disabled.")
+            self.ds_depth_1 = None
+            self.ds_timesteps_1 = None
+            self.ds_depth_2 = None
+            self.ds_timesteps_2 = None
+            self.ds_ratio = None
+        else:
+            print(
+                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
+            )
+            self.ds_depth_1 = ds_depth_1
+            self.ds_timesteps_1 = ds_timesteps_1
+            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
+            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
+            self.ds_ratio = ds_ratio
+
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+        r"""
+        current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
+        """
+        _self = self.delegate
+
+        # broadcast timesteps to batch dimension
+        timesteps = timesteps.expand(x.shape[0])
+
+        hs = []
+        t_emb = get_timestep_embedding(timesteps, _self.model_channels)  # , repeat_only=False)
+        t_emb = t_emb.to(x.dtype)
+        emb = _self.time_embed(t_emb)
+
+        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
+        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
+        # assert x.dtype == _self.dtype
+        emb = emb + _self.label_emb(y)
+
+        def call_module(module, h, emb, context):
+            x = h
+            for layer in module:
+                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
+                if isinstance(layer, ResnetBlock2D):
+                    x = layer(x, emb)
+                elif isinstance(layer, Transformer2DModel):
+                    x = layer(x, context)
+                else:
+                    x = layer(x)
+            return x
+
+        # h = x.type(self.dtype)
+        h = x
+
+        for depth, module in enumerate(_self.input_blocks):
+            # Deep Shrink
+            if self.ds_depth_1 is not None:
+                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
+                    self.ds_depth_2 is not None
+                    and depth == self.ds_depth_2
+                    and timesteps[0] < self.ds_timesteps_1
+                    and timesteps[0] >= self.ds_timesteps_2
+                ):
+                    # print("downsample", h.shape, self.ds_ratio)
+                    org_dtype = h.dtype
+                    if org_dtype == torch.bfloat16:
+                        h = h.to(torch.float32)
+                    h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
+
+            h = call_module(module, h, emb, context)
+            hs.append(h)
+
+        h = call_module(_self.middle_block, h, emb, context)
+
+        for module in _self.output_blocks:
+            # Deep Shrink
+            if self.ds_depth_1 is not None:
+                if hs[-1].shape[-2:] != h.shape[-2:]:
+                    # print("upsample", h.shape, hs[-1].shape)
+                    h = resize_like(h, hs[-1])
+
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = call_module(module, h, emb, context)
+
+        # Deep Shrink: in case of depth 0
+        if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
+            # print("upsample", h.shape, x.shape)
+            h = resize_like(h, x)
+
+        h = h.type(x.dtype)
+        h = call_module(_self.out, h, emb, context)
+
+        return h
+
+
+if __name__ == "__main__":
+    import time
+
+    print("create unet")
+    unet = SdxlUNet2DConditionModel()
+
+    unet.to("cuda")
+    unet.set_use_memory_efficient_attention(True, False)
+    unet.set_gradient_checkpointing(True)
+    unet.train()
+
+    # pseudo training loop for checking memory usage
+    print("preparing optimizer")
+
+    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
+
+    # import bitsandbytes
+    # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3)        # not working
+    # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2
+    # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2
+
+    import transformers
+
+    optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True)  # working at 22.2GB with torch2
+
+    scaler = torch.cuda.amp.GradScaler(enabled=True)
+
+    print("start training")
+    steps = 10
+    batch_size = 1
+
+    for step in range(steps):
+        print(f"step {step}")
+        if step == 1:
+            time_start = time.perf_counter()
+
+        x = torch.randn(batch_size, 4, 128, 128).cuda()  # 1024x1024
+        t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
+        ctx = torch.randn(batch_size, 77, 2048).cuda()
+        y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
+
+        with torch.cuda.amp.autocast(enabled=True):
+            output = unet(x, t, ctx, y)
+            target = torch.randn_like(output)
+            loss = torch.nn.functional.mse_loss(output, target)
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad(set_to_none=True)
+
+    time_end = time.perf_counter()
+    print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
diff --git a/external/llite/library/sdxl_train_util.py b/external/llite/library/sdxl_train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce6bc4b45b822f5ec9df4eca6fa168bf1956164e
--- /dev/null
+++ b/external/llite/library/sdxl_train_util.py
@@ -0,0 +1,367 @@
+import argparse
+import gc
+import math
+import os
+from typing import Optional
+import torch
+from accelerate import init_empty_weights
+from tqdm import tqdm
+from transformers import CLIPTokenizer
+from external.llite.library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from external.llite.library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+
+TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
+TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+
+# DEFAULT_NOISE_OFFSET = 0.0357
+
+
+def load_target_model(args, accelerator, model_version: str, weight_dtype):
+    # load models for each process
+    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
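+    # processes take turns loading, presumably to avoid simultaneous downloads and peak-memory spikes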
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            (
+                load_stable_diffusion_format,
+                text_encoder1,
+                text_encoder2,
+                vae,
+                unet,
+                logit_scale,
+                ckpt_info,
+            ) = _load_target_model(
+                args.pretrained_model_name_or_path,
+                args.vae,
+                model_version,
+                weight_dtype,
+                accelerator.device if args.lowram else "cpu",
+                model_dtype,
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder1.to(accelerator.device)
+                text_encoder2.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
+
+
+def _load_target_model(
+    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
+):
+    # model_dtype only work with full fp16/bf16
+    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
+
+    if load_stable_diffusion_format:
+        print(f"load StableDiffusion checkpoint: {name_or_path}")
+        (
+            text_encoder1,
+            text_encoder2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
+    else:
+        # Diffusers model is loaded to CPU
+        from diffusers import StableDiffusionXLPipeline
+
+        variant = "fp16" if weight_dtype == torch.float16 else None
+        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
+        try:
+            try:
+                pipe = StableDiffusionXLPipeline.from_pretrained(
+                    name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
+                )
+            except EnvironmentError as ex:
+                if variant is not None:
+                    print("try to load fp32 model")
+                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
+                else:
+                    raise ex
+        except EnvironmentError as ex:
+            print(
+                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
+            )
+            raise ex
+
+        text_encoder1 = pipe.text_encoder
+        text_encoder2 = pipe.text_encoder_2
+
+        # convert to fp32 for cache text_encoders outputs
+        if text_encoder1.dtype != torch.float32:
+            text_encoder1 = text_encoder1.to(dtype=torch.float32)
+        if text_encoder2.dtype != torch.float32:
+            text_encoder2 = text_encoder2.to(dtype=torch.float32)
+
+        vae = pipe.vae
+        unet = pipe.unet
+        del pipe
+
+        # Diffusers U-Net to original U-Net
+        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
+        with init_empty_weights():
+            unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
+        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
+        print("U-Net converted to original U-Net")
+
+        logit_scale = None
+        ckpt_info = None
+
+    # load the VAE
+    if vae_path is not None:
+        vae = model_util.load_vae(vae_path, weight_dtype)
+        print("additional VAE loaded")
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
+
+
+def load_tokenizers(args: argparse.Namespace):
+    print("prepare tokenizers")
+
+    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
+    tokenizers = []
+    for i, original_path in enumerate(original_paths):
+        tokenizer: CLIPTokenizer = None
+        if args.tokenizer_cache_dir:
+            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
+            if os.path.exists(local_tokenizer_path):
+                print(f"load tokenizer from cache: {local_tokenizer_path}")
+                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
+
+        if tokenizer is None:
+            tokenizer = CLIPTokenizer.from_pretrained(original_path)
+
+        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+            print(f"save Tokenizer to cache: {local_tokenizer_path}")
+            tokenizer.save_pretrained(local_tokenizer_path)
+
+        if i == 1:
+            tokenizer.pad_token_id = 0  # fix pad token id to make same as open clip tokenizer
+
+        tokenizers.append(tokenizer)
+
+    if hasattr(args, "max_token_length") and args.max_token_length is not None:
+        print(f"update token length: {args.max_token_length}")
+
+    return tokenizers
+
+
+def match_mixed_precision(args, weight_dtype):
+    if args.full_fp16:
+        assert (
+            weight_dtype == torch.float16
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        return weight_dtype
+    elif args.full_bf16:
+        assert (
+            weight_dtype == torch.bfloat16
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        return weight_dtype
+    else:
+        return None
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        device=timesteps.device
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def get_timestep_embedding(x, outdim):
+    assert len(x.shape) == 2
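+    # x is a (B, dims) tensor (e.g. (B, 2) image sizes); each scalar is embedded to `outdim` dims and
+    # the result is flattened to (B, dims * outdim)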
+    b, dims = x.shape[0], x.shape[1]
+    x = torch.flatten(x)
+    emb = timestep_embedding(x, outdim)
+    emb = torch.reshape(emb, (b, dims * outdim))
+    return emb
+
+
+def get_size_embeddings(orig_size, crop_size, target_size, device):
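+    # orig_size / crop_size / target_size are (B, 2) [h, w] tensors; each is embedded to 2 * 256 dims,
+    # giving a (B, 1536) vector used for the ADM conditioning (in SDXL, 2816 = 1280 pooled text emb + 1536)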
+    emb1 = get_timestep_embedding(orig_size, 256)
+    emb2 = get_timestep_embedding(crop_size, 256)
+    emb3 = get_timestep_embedding(target_size, 256)
+    vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
+    return vector
+
+
+def save_sd_model_on_train_end(
+    args: argparse.Namespace,
+    src_path: str,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    logit_scale,
+    ckpt_info,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            logit_scale,
+            sai_metadata,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
+
+    train_util.save_sd_model_on_train_end_common(
+        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
+    )
+
+
+# Saving at epoch end and at step intervals is merged into one function because the metadata
+# includes epoch/step and the arguments are otherwise identical.
+# on_epoch_end: True when called at the end of an epoch, False when called after a step interval
+def save_sd_model_on_epoch_end_or_stepwise(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    src_path,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    logit_scale,
+    ckpt_info,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            logit_scale,
+            sai_metadata,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
+
+    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
+        args,
+        on_epoch_end,
+        accelerator,
+        save_stable_diffusion_format,
+        use_safetensors,
+        epoch,
+        num_train_epochs,
+        global_step,
+        sd_saver,
+        diffusers_saver,
+    )
+
+
+def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
+    )
+    parser.add_argument(
+        "--cache_text_encoder_outputs_to_disk",
+        action="store_true",
+        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
+    )
+
+
+def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
+    assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
+    if args.v_parameterization:
+        print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
+
+    if args.clip_skip is not None:
+        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
+
+    # if args.multires_noise_iterations:
+    #     print(
+    #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
+    #     )
+    # else:
+    #     if args.noise_offset is None:
+    #         args.noise_offset = DEFAULT_NOISE_OFFSET
+    #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
+    #         print(
+    #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
+    #         )
+    #     print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
+
+    assert (
+        not hasattr(args, "weighted_captions") or not args.weighted_captions
+    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
+
+    if supportTextEncoderCaching:
+        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
+            args.cache_text_encoder_outputs = True
+            print(
+                "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+                + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
+            )
+
+
+def sample_images(*args, **kwargs):
+    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
diff --git a/external/llite/library/slicing_vae.py b/external/llite/library/slicing_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4e056d39679f67d408a42ec2c92bed1b691c80
--- /dev/null
+++ b/external/llite/library/slicing_vae.py
@@ -0,0 +1,679 @@
+# Modified from Diffusers to reduce VRAM usage
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+
+
+def slice_h(x, num_slices):
+    # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
+    # works with both NCHW and NHWC layouts
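+    # Illustrative example: H=10 and num_slices=2 gives size=5, so the slices cover rows
+    # [0:6] and [4:10]; the one extra row on each side of the cut means a 3x3 conv with
+    # padding=1 produces the same interior values as on the full tensor, and cat_h() drops
+    # the duplicated boundary rows again.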
+    size = (x.shape[2] + num_slices - 1) // num_slices
+    sliced = []
+    for i in range(num_slices):
+        if i == 0:
+            sliced.append(x[:, :, : size + 1, :])
+        else:
+            end = size * (i + 1) + 1
+            if x.shape[2] - end < 3:  # if the last slice is too thin for conv2d, use the rest of the tensor
+                end = x.shape[2]
+            sliced.append(x[:, :, size * i - 1 : end, :])
+            if end >= x.shape[2]:
+                break
+    return sliced
+
+
+def cat_h(sliced):
+    # concatenate the slices, dropping the overlap rows that were added as padding
+    cat = []
+    for i, x in enumerate(sliced):
+        if i == 0:
+            cat.append(x[:, :, :-1, :])
+        elif i == len(sliced) - 1:
+            cat.append(x[:, :, 1:, :])
+        else:
+            cat.append(x[:, :, 1:-1, :])
+        del x
+    x = torch.cat(cat, dim=2)
+    return x
+
+
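+# resblock_forward replaces the forward of the VAE resnet blocks: the two GroupNorms run
+# on the CPU over the full tensor (slicing a GroupNorm would change its statistics), while
+# the convolutions run on the GPU one horizontal slice at a time, bounding peak VRAM usage
+# at the cost of extra host<->device transfers.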
+def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
+    assert _self.upsample is None and _self.downsample is None
+    assert _self.norm1.num_groups == _self.norm2.num_groups
+    assert temb is None
+
+    # make sure norms are on cpu
+    org_device = input_tensor.device
+    cpu_device = torch.device("cpu")
+    _self.norm1.to(cpu_device)
+    _self.norm2.to(cpu_device)
+
+    # workaround: GroupNorm does not run in fp16 on the CPU
+    org_dtype = input_tensor.dtype
+    if org_dtype == torch.float16:
+        _self.norm1.to(torch.float32)
+        _self.norm2.to(torch.float32)
+
+    # move all tensors to the CPU
+    input_tensor = input_tensor.to(cpu_device)
+    hidden_states = input_tensor
+
+    # this sliced norm seems to give different results...
+    # def sliced_norm1(norm, x):
+    #     num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
+    #     sliced_tensor = torch.chunk(x, num_div, dim=1)
+    #     sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
+    #     sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
+    #     print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
+    #     normed_tensor = []
+    #     for i in range(num_div):
+    #         n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
+    #         normed_tensor.append(n)
+    #         del n
+    #     x = torch.cat(normed_tensor, dim=1)
+    #     return num_div, x
+
+    # splitting the norm changes the result, so only the norms are left unsliced; running them on the
+    # GPU would run out of VRAM, so they run on the CPU, which fortunately is not too slow
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float32)
+    hidden_states = _self.norm1(hidden_states)  # run on cpu
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float16)
+
+    sliced = slice_h(hidden_states, num_slices)
+    del hidden_states
+
+    for i in range(len(sliced)):
+        x = sliced[i]
+        sliced[i] = None
+
+        # move only the slice being computed to the GPU; same pattern below
+        x = x.to(org_device)
+        x = _self.nonlinearity(x)
+        x = _self.conv1(x)
+        x = x.to(cpu_device)
+        sliced[i] = x
+        del x
+
+    hidden_states = cat_h(sliced)
+    del sliced
+
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float32)
+    hidden_states = _self.norm2(hidden_states)  # run on cpu
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float16)
+
+    sliced = slice_h(hidden_states, num_slices)
+    del hidden_states
+
+    for i in range(len(sliced)):
+        x = sliced[i]
+        sliced[i] = None
+
+        x = x.to(org_device)
+        x = _self.nonlinearity(x)
+        x = _self.dropout(x)
+        x = _self.conv2(x)
+        x = x.to(cpu_device)
+        sliced[i] = x
+        del x
+
+    hidden_states = cat_h(sliced)
+    del sliced
+
+    # make shortcut
+    if _self.conv_shortcut is not None:
+        sliced = list(torch.chunk(input_tensor, num_slices, dim=2))  # conv_shortcut has no padding, so plain chunking is fine
+        del input_tensor
+
+        for i in range(len(sliced)):
+            x = sliced[i]
+            sliced[i] = None
+
+            x = x.to(org_device)
+            x = _self.conv_shortcut(x)
+            x = x.to(cpu_device)
+            sliced[i] = x
+            del x
+
+        input_tensor = torch.cat(sliced, dim=2)
+        del sliced
+
+    output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
+
+    output_tensor = output_tensor.to(org_device)  # the next layer computes on the GPU
+    return output_tensor
+
+
+class SlicingEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        norm_num_groups=32,
+        act_fn="silu",
+        double_z=True,
+        num_slices=2,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+
+        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+        self.mid_block = None
+        self.down_blocks = nn.ModuleList([])
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=self.layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                add_downsample=not is_final_block,
+                resnet_eps=1e-6,
+                downsample_padding=0,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                attention_head_dim=output_channel,
+                temb_channels=None,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attention_head_dim=block_out_channels[-1],
+            resnet_groups=norm_num_groups,
+            temb_channels=None,
+        )
+        self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)  # use Diffusers' xformers attention for now
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+        self.conv_act = nn.SiLU()
+
+        conv_out_channels = 2 * out_channels if double_z else out_channels
+        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+        # replace forward of ResBlocks
+        def wrapper(func, module, num_slices):
+            def forward(*args, **kwargs):
+                return func(module, num_slices, *args, **kwargs)
+
+            return forward
+
+        self.num_slices = num_slices
+        div = num_slices / (2 ** (len(self.down_blocks) - 1))  # deeper blocks do not need as many slices, so reduce the divisor accordingly
+        # print(f"initial divisor: {div}")
+        if div >= 2:
+            div = int(div)
+            for resnet in self.mid_block.resnets:
+                resnet.forward = wrapper(resblock_forward, resnet, div)
+            # midblock doesn't have downsample
+
+        for i, down_block in enumerate(self.down_blocks[::-1]):
+            if div >= 2:
+                div = int(div)
+                # print(f"down block: {i} divisor: {div}")
+                for resnet in down_block.resnets:
+                    resnet.forward = wrapper(resblock_forward, resnet, div)
+                if down_block.downsamplers is not None:
+                    # print("has downsample")
+                    for downsample in down_block.downsamplers:
+                        downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
+            div *= 2
+
+    def forward(self, x):
+        sample = x
+        del x
+
+        org_device = sample.device
+        cpu_device = torch.device("cpu")
+
+        # sample = self.conv_in(sample)
+        sample = sample.to(cpu_device)
+        sliced = slice_h(sample, self.num_slices)
+        del sample
+
+        for i in range(len(sliced)):
+            x = sliced[i]
+            sliced[i] = None
+
+            x = x.to(org_device)
+            x = self.conv_in(x)
+            x = x.to(cpu_device)
+            sliced[i] = x
+            del x
+
+        sample = cat_h(sliced)
+        del sliced
+
+        sample = sample.to(org_device)
+
+        # down
+        for down_block in self.down_blocks:
+            sample = down_block(sample)
+
+        # middle
+        sample = self.mid_block(sample)
+
+        # post-process
+        # this could also be sliced to save memory, but it probably does not use much, so it is left as is
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample
+
+    def downsample_forward(self, _self, num_slices, hidden_states):
+        assert hidden_states.shape[1] == _self.channels
+        assert _self.use_conv and _self.padding == 0
+        print("downsample forward", num_slices, hidden_states.shape)
+
+        org_device = hidden_states.device
+        cpu_device = torch.device("cpu")
+
+        hidden_states = hidden_states.to(cpu_device)
+        pad = (0, 1, 0, 1)
+        hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
+
+        # slice with an even size because of stride 2
+        # slice with pad 1 on both sides to eliminate the side effect of conv2d padding
+        size = (hidden_states.shape[2] + num_slices - 1) // num_slices
+        size = size + 1 if size % 2 == 1 else size
+
+        sliced = []
+        for i in range(num_slices):
+            if i == 0:
+                sliced.append(hidden_states[:, :, : size + 1, :])
+            else:
+                end = size * (i + 1) + 1
+                if hidden_states.shape[2] - end < 4:  # if the last slice is too small, use the rest of the tensor
+                    end = hidden_states.shape[2]
+                sliced.append(hidden_states[:, :, size * i - 1 : end, :])
+                if end >= hidden_states.shape[2]:
+                    break
+        del hidden_states
+
+        for i in range(len(sliced)):
+            x = sliced[i]
+            sliced[i] = None
+
+            x = x.to(org_device)
+            x = _self.conv(x)
+            x = x.to(cpu_device)
+
+            # this part looks different from the rest because Copilot wrote it
+            if i == 0:
+                hidden_states = x
+            else:
+                hidden_states = torch.cat([hidden_states, x], dim=2)
+
+        hidden_states = hidden_states.to(org_device)
+        # print("downsample forward done", hidden_states.shape)
+        return hidden_states
+
+
+class SlicingDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        out_channels=3,
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        norm_num_groups=32,
+        act_fn="silu",
+        num_slices=2,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attention_head_dim=block_out_channels[-1],
+            resnet_groups=norm_num_groups,
+            temb_channels=None,
+        )
+        self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)  # use Diffusers' xformers attention for now
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+
+            is_final_block = i == len(block_out_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=self.layers_per_block + 1,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                prev_output_channel=None,
+                add_upsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                attention_head_dim=output_channel,
+                temb_channels=None,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+        # replace forward of ResBlocks
+        def wrapper(func, module, num_slices):
+            def forward(*args, **kwargs):
+                return func(module, num_slices, *args, **kwargs)
+
+            return forward
+
+        self.num_slices = num_slices
+        div = num_slices / (2 ** (len(self.up_blocks) - 1))
+        print(f"initial divisor: {div}")
+        if div >= 2:
+            div = int(div)
+            for resnet in self.mid_block.resnets:
+                resnet.forward = wrapper(resblock_forward, resnet, div)
+            # midblock doesn't have upsample
+
+        for i, up_block in enumerate(self.up_blocks):
+            if div >= 2:
+                div = int(div)
+                # print(f"up block: {i} divisor: {div}")
+                for resnet in up_block.resnets:
+                    resnet.forward = wrapper(resblock_forward, resnet, div)
+                if up_block.upsamplers is not None:
+                    # print("has upsample")
+                    for upsample in up_block.upsamplers:
+                        upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
+            div *= 2
+
+    def forward(self, z):
+        sample = z
+        del z
+        sample = self.conv_in(sample)
+
+        # middle
+        sample = self.mid_block(sample)
+
+        # up
+        for i, up_block in enumerate(self.up_blocks):
+            sample = up_block(sample)
+
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+
+        # conv_out uses a lot of VRAM, so run it slice by slice
+        org_device = sample.device
+        cpu_device = torch.device("cpu")
+        sample = sample.to(cpu_device)
+
+        sliced = slice_h(sample, self.num_slices)
+        del sample
+        for i in range(len(sliced)):
+            x = sliced[i]
+            sliced[i] = None
+
+            x = x.to(org_device)
+            x = self.conv_out(x)
+            x = x.to(cpu_device)
+            sliced[i] = x
+        sample = cat_h(sliced)
+        del sliced
+
+        sample = sample.to(org_device)
+        return sample
+
+    def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
+        assert hidden_states.shape[1] == _self.channels
+        assert not _self.use_conv_transpose and _self.use_conv
+
+        org_dtype = hidden_states.dtype
+        org_device = hidden_states.device
+        cpu_device = torch.device("cpu")
+
+        hidden_states = hidden_states.to(cpu_device)
+        sliced = slice_h(hidden_states, num_slices)
+        del hidden_states
+
+        for i in range(len(sliced)):
+            x = sliced[i]
+            sliced[i] = None
+
+            x = x.to(org_device)
+
+            # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+            # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+            # https://github.com/pytorch/pytorch/issues/86679
+            # hopefully this gets fixed in PyTorch 2...
+            if org_dtype == torch.bfloat16:
+                x = x.to(torch.float32)
+
+            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+            if org_dtype == torch.bfloat16:
+                x = x.to(org_dtype)
+
+            x = _self.conv(x)
+
+            # the slices were upsampled 2x, so the 1-row pad becomes 2 rows and is trimmed here
+            if i == 0:
+                x = x[:, :, :-2, :]
+            elif i == num_slices - 1:
+                x = x[:, :, 2:, :]
+            else:
+                x = x[:, :, 2:-2, :]
+
+            x = x.to(cpu_device)
+            sliced[i] = x
+            del x
+
+        hidden_states = torch.cat(sliced, dim=2)
+        # print("us hidden_states", hidden_states.shape)
+        del sliced
+
+        hidden_states = hidden_states.to(org_device)
+        return hidden_states
+
+
+class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
+    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+    and Max Welling.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the model (such as downloading or saving, etc.)
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to :
+            obj:`(64,)`): Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): TODO
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 4,
+        norm_num_groups: int = 32,
+        sample_size: int = 32,
+        num_slices: int = 16,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = SlicingEncoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=True,
+            num_slices=num_slices,
+        )
+
+        # pass init params to Decoder
+        self.decoder = SlicingDecoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            norm_num_groups=norm_num_groups,
+            act_fn=act_fn,
+            num_slices=num_slices,
+        )
+
+        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.use_slicing = False
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    # note: this is slicing along the batch dimension, not the H slicing above; the naming is confusing
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
diff --git a/external/llite/library/train_util.py b/external/llite/library/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e7879964e0654ffcfaa671f51c8449bb3536460
--- /dev/null
+++ b/external/llite/library/train_util.py
@@ -0,0 +1,4856 @@
+# common functions for training
+
+import argparse
+import ast
+import asyncio
+import datetime
+import importlib
+import json
+import pathlib
+import re
+import shutil
+import time
+from typing import (
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+import gc
+import glob
+import math
+import os
+import random
+import hashlib
+import subprocess
+from io import BytesIO
+import toml
+
+from tqdm import tqdm
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torchvision import transforms
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
+import transformers
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers import (
+    StableDiffusionPipeline,
+    DDPMScheduler,
+    EulerAncestralDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    AutoencoderKL,
+)
+from external.llite.library import custom_train_functions
+from external.llite.library.original_unet import UNet2DConditionModel
+from huggingface_hub import hf_hub_download
+import numpy as np
+from PIL import Image
+import cv2
+import safetensors.torch
+from external.llite.library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
+import external.llite.library.model_util as model_util
+import external.llite.library.huggingface_util as huggingface_util
+import external.llite.library.sai_model_spec as sai_model_spec
+
+# from library.attention_processors import FlashAttnProcessor
+# from library.hypernetwork import replace_attentions_for_hypernetwork
+from external.llite.library.original_unet import UNet2DConditionModel
+
+# Tokenizer: use the pre-provided tokenizers instead of loading them from the checkpoint
+TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
+
+# checkpoint file name templates
+EPOCH_STATE_NAME = "{}-{:06d}-state"
+EPOCH_FILE_NAME = "{}-{:06d}"
+EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
+LAST_STATE_NAME = "{}-state"
+DEFAULT_EPOCH_NAME = "epoch"
+DEFAULT_LAST_OUTPUT_NAME = "last"
+
+DEFAULT_STEP_NAME = "at"
+STEP_STATE_NAME = "{}-step{:08d}-state"
+STEP_FILE_NAME = "{}-step{:08d}"
+STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
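+# For illustration, the templates above expand as:
+#   STEP_FILE_NAME.format("last", 1000)  -> "last-step00001000"
+#   EPOCH_STATE_NAME.format("last", 3)   -> "last-000003-state"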
+
+# region dataset
+
+IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
+
+try:
+    import pillow_avif
+
+    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
+except:
+    pass
+
+# JPEG-XL on Linux
+try:
+    from jxlpy import JXLImagePlugin
+
+    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+    pass
+
+# JPEG-XL on Windows
+try:
+    import pillow_jxl
+
+    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+    pass
+
+IMAGE_TRANSFORMS = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
+
+
+class ImageInfo:
+    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
+        self.image_key: str = image_key
+        self.num_repeats: int = num_repeats
+        self.caption: str = caption
+        self.is_reg: bool = is_reg
+        self.absolute_path: str = absolute_path
+        self.image_size: Tuple[int, int] = None
+        self.resized_size: Tuple[int, int] = None
+        self.bucket_reso: Tuple[int, int] = None
+        self.latents: torch.Tensor = None
+        self.latents_flipped: torch.Tensor = None
+        self.latents_npz: str = None
+        self.latents_original_size: Tuple[int, int] = None  # original image size, not latents size
+        self.latents_crop_ltrb: Tuple[int, int] = None  # crop left top right bottom in original pixel size, not latents size
+        self.cond_img_path: str = None
+        self.image: Optional[Image.Image] = None  # optional, original PIL Image
+        # SDXL, optional
+        self.text_encoder_outputs_npz: Optional[str] = None
+        self.text_encoder_outputs1: Optional[torch.Tensor] = None
+        self.text_encoder_outputs2: Optional[torch.Tensor] = None
+        self.text_encoder_pool2: Optional[torch.Tensor] = None
+
+
+class BucketManager:
+    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
+        if max_size is not None:
+            if max_reso is not None:
+                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
+                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
+            if min_size is not None:
+                assert max_size >= min_size, "the max_size should be larger than the min_size"
+
+        self.no_upscale = no_upscale
+        if max_reso is None:
+            self.max_reso = None
+            self.max_area = None
+        else:
+            self.max_reso = max_reso
+            self.max_area = max_reso[0] * max_reso[1]
+        self.min_size = min_size
+        self.max_size = max_size
+        self.reso_steps = reso_steps
+
+        self.resos = []
+        self.reso_to_id = {}
+        self.buckets = []  # (image_key, image, original size, crop left/top) during preprocessing, image_key during training
+
+    def add_image(self, reso, image_or_info):
+        bucket_id = self.reso_to_id[reso]
+        self.buckets[bucket_id].append(image_or_info)
+
+    def shuffle(self):
+        for bucket in self.buckets:
+            random.shuffle(bucket)
+
+    def sort(self):
+        # sort by resolution (purely to improve readability of logs and stored metadata); reorder buckets and reassign reso_to_id accordingly
+        sorted_resos = self.resos.copy()
+        sorted_resos.sort()
+
+        sorted_buckets = []
+        sorted_reso_to_id = {}
+        for i, reso in enumerate(sorted_resos):
+            bucket_id = self.reso_to_id[reso]
+            sorted_buckets.append(self.buckets[bucket_id])
+            sorted_reso_to_id[reso] = i
+
+        self.resos = sorted_resos
+        self.buckets = sorted_buckets
+        self.reso_to_id = sorted_reso_to_id
+
+    def make_buckets(self):
+        resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
+        self.set_predefined_resos(resos)
+
+    def set_predefined_resos(self, resos):
+        # store the resolutions and aspect ratios used when selecting from the predefined sizes
+        self.predefined_resos = resos.copy()
+        self.predefined_resos_set = set(resos)
+        self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
+
+    def add_if_new_reso(self, reso):
+        if reso not in self.reso_to_id:
+            bucket_id = len(self.resos)
+            self.reso_to_id[reso] = bucket_id
+            self.resos.append(reso)
+            self.buckets.append([])
+            # print(reso, bucket_id, len(self.buckets))
+
+    def round_to_steps(self, x):
+        x = int(x + 0.5)
+        return x - x % self.reso_steps
+
+    def select_bucket(self, image_width, image_height):
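+        # Illustrative example (values assumed, not from a real config): with bucket_no_upscale
+        # and reso_steps=64, a 1000x700 image whose area is below max_area keeps its size and
+        # gets the bucket (1000 - 1000 % 64, 700 - 700 % 64) = (960, 640). Without no_upscale,
+        # the predefined bucket with the closest aspect ratio is chosen and the image is scaled
+        # so that it covers the bucket; the overflow is cropped later.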
+        aspect_ratio = image_width / image_height
+        if not self.no_upscale:
+            # both upscaling and downscaling are performed
+            # the same aspect ratio may appear more than once (fine tuning preprocessed with no_upscale=True), so prefer a bucket with exactly the same resolution
+            reso = (image_width, image_height)
+            if reso in self.predefined_resos_set:
+                pass
+            else:
+                ar_errors = self.predefined_aspect_ratios - aspect_ratio
+                predefined_bucket_id = np.abs(ar_errors).argmin()  # the predefined bucket with the smallest aspect ratio error
+                reso = self.predefined_resos[predefined_bucket_id]
+
+            ar_reso = reso[0] / reso[1]
+            if aspect_ratio > ar_reso:  # the image is wider than the bucket -> match the height
+                scale = reso[1] / image_height
+            else:
+                scale = reso[0] / image_width
+
+            resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
+            # print("use predef", image_width, image_height, reso, resized_size)
+        else:
+            # only downscaling is performed
+            if image_width * image_height > self.max_area:
+                # the image is too large, so choose the bucket assuming it is downscaled while keeping its aspect ratio
+                resized_width = math.sqrt(self.max_area * aspect_ratio)
+                resized_height = self.max_area / resized_width
+                assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
+
+                # round the resized short or long side to a multiple of reso_steps, choosing whichever keeps the aspect ratio error smaller
+                # same logic as the original bucketing
+                b_width_rounded = self.round_to_steps(resized_width)
+                b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
+                ar_width_rounded = b_width_rounded / b_height_in_wr
+
+                b_height_rounded = self.round_to_steps(resized_height)
+                b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
+                ar_height_rounded = b_width_in_hr / b_height_rounded
+
+                # print(b_width_rounded, b_height_in_wr, ar_width_rounded)
+                # print(b_width_in_hr, b_height_rounded, ar_height_rounded)
+
+                if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
+                    resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5))
+                else:
+                    resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded)
+                # print(resized_size)
+            else:
+                resized_size = (image_width, image_height)  # no resize needed
+
+            # the bucket size is at most the image size (crop instead of padding)
+            bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
+            bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
+            # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
+
+            reso = (bucket_width, bucket_height)
+
+        self.add_if_new_reso(reso)
+
+        ar_error = (reso[0] / reso[1]) - aspect_ratio
+        return reso, resized_size, ar_error
+
+    @staticmethod
+    def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
+        # Calculate crop left/top to match Stability AI's preprocessing; crop right is also computed for the flip augmentation.
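+        # Illustrative example: bucket_reso=(1024, 512) and image_size=(768, 768) resize the
+        # image to (512, 512), giving crop_left=256, crop_top=0, crop_right=768, crop_bottom=512.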
+
+        bucket_ar = bucket_reso[0] / bucket_reso[1]
+        image_ar = image_size[0] / image_size[1]
+        if bucket_ar > image_ar:
+            # the bucket is wider than the image -> match the height
+            resized_width = bucket_reso[1] * image_ar
+            resized_height = bucket_reso[1]
+        else:
+            resized_width = bucket_reso[0]
+            resized_height = bucket_reso[0] / image_ar
+        crop_left = (bucket_reso[0] - resized_width) // 2
+        crop_top = (bucket_reso[1] - resized_height) // 2
+        crop_right = crop_left + resized_width
+        crop_bottom = crop_top + resized_height
+        return crop_left, crop_top, crop_right, crop_bottom
+
+
+class BucketBatchIndex(NamedTuple):
+    bucket_index: int
+    bucket_batch_size: int
+    batch_index: int
+
+
+class AugHelper:
+    # the dependency on albumentations was removed, but the same interface is kept for now
+
+    def __init__(self):
+        pass
+
+    def color_aug(self, image: np.ndarray):
+        # self.color_aug_method = albu.OneOf(
+        #     [
+        #         albu.HueSaturationValue(8, 0, 0, p=0.5),
+        #         albu.RandomGamma((95, 105), p=0.5),
+        #     ],
+        #     p=0.33,
+        # )
+        hue_shift_limit = 8
+
+        # remove dependency to albumentations
+        if random.random() <= 0.33:
+            if random.random() > 0.5:
+                # hue shift
+                hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+                hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
+                if hue_shift < 0:
+                    hue_shift = 180 + hue_shift
+                hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
+                image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
+            else:
+                # random gamma
+                gamma = random.uniform(0.95, 1.05)
+                image = np.clip(image**gamma, 0, 255).astype(np.uint8)
+
+        return {"image": image}
+
+    def get_augmentor(self, use_color_aug: bool):  # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
+        return self.color_aug if use_color_aug else None
+
+
+class BaseSubset:
+    def __init__(
+        self,
+        image_dir: Optional[str],
+        num_repeats: int,
+        shuffle_caption: bool,
+        caption_separator: str,
+        keep_tokens: int,
+        keep_tokens_separator: str,
+        color_aug: bool,
+        flip_aug: bool,
+        face_crop_aug_range: Optional[Tuple[float, float]],
+        random_crop: bool,
+        caption_dropout_rate: float,
+        caption_dropout_every_n_epochs: int,
+        caption_tag_dropout_rate: float,
+        caption_prefix: Optional[str],
+        caption_suffix: Optional[str],
+        token_warmup_min: int,
+        token_warmup_step: Union[float, int],
+    ) -> None:
+        self.image_dir = image_dir
+        self.num_repeats = num_repeats
+        self.shuffle_caption = shuffle_caption
+        self.caption_separator = caption_separator
+        self.keep_tokens = keep_tokens
+        self.keep_tokens_separator = keep_tokens_separator
+        self.color_aug = color_aug
+        self.flip_aug = flip_aug
+        self.face_crop_aug_range = face_crop_aug_range
+        self.random_crop = random_crop
+        self.caption_dropout_rate = caption_dropout_rate
+        self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
+        self.caption_tag_dropout_rate = caption_tag_dropout_rate
+        self.caption_prefix = caption_prefix
+        self.caption_suffix = caption_suffix
+
+        self.token_warmup_min = token_warmup_min  # number of tags kept at step 0
+        self.token_warmup_step = token_warmup_step  # step at which the tag count reaches its maximum (N * max_train_steps if N < 1)
+
+        self.img_count = 0
+
+
+class DreamBoothSubset(BaseSubset):
+    def __init__(
+        self,
+        image_dir: str,
+        is_reg: bool,
+        class_tokens: Optional[str],
+        caption_extension: str,
+        num_repeats,
+        shuffle_caption,
+        caption_separator: str,
+        keep_tokens,
+        keep_tokens_separator,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
+    ) -> None:
+        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
+
+        super().__init__(
+            image_dir,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
+        )
+
+        self.is_reg = is_reg
+        self.class_tokens = class_tokens
+        self.caption_extension = caption_extension
+        if self.caption_extension and not self.caption_extension.startswith("."):
+            self.caption_extension = "." + self.caption_extension
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, DreamBoothSubset):
+            return NotImplemented
+        return self.image_dir == other.image_dir
+
+
+class FineTuningSubset(BaseSubset):
+    def __init__(
+        self,
+        image_dir,
+        metadata_file: str,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
+    ) -> None:
+        assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
+
+        super().__init__(
+            image_dir,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
+        )
+
+        self.metadata_file = metadata_file
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, FineTuningSubset):
+            return NotImplemented
+        return self.metadata_file == other.metadata_file
+
+
+class ControlNetSubset(BaseSubset):
+    def __init__(
+        self,
+        image_dir: str,
+        conditioning_data_dir: str,
+        caption_extension: str,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
+    ) -> None:
+        assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
+
+        super().__init__(
+            image_dir,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
+        )
+
+        self.conditioning_data_dir = conditioning_data_dir
+        self.caption_extension = caption_extension
+        if self.caption_extension and not self.caption_extension.startswith("."):
+            self.caption_extension = "." + self.caption_extension
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, ControlNetSubset):
+            return NotImplemented
+        return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
+
+
+class BaseDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
+        max_token_length: int,
+        resolution: Optional[Tuple[int, int]],
+        debug_dataset: bool,
+    ) -> None:
+        super().__init__()
+
+        self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
+
+        self.max_token_length = max_token_length
+        # width/height is used when enable_bucket==False
+        self.width, self.height = (None, None) if resolution is None else resolution
+        self.debug_dataset = debug_dataset
+
+        self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
+
+        self.token_padding_disabled = False
+        self.tag_frequency = {}
+        self.XTI_layers = None
+        self.token_strings = None
+
+        self.enable_bucket = False
+        self.bucket_manager: BucketManager = None  # not initialized
+        self.min_bucket_reso = None
+        self.max_bucket_reso = None
+        self.bucket_reso_steps = None
+        self.bucket_no_upscale = None
+        self.bucket_info = None  # for metadata
+
+        self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
+
+        self.current_epoch: int = 0  # the dataset instance seems to be recreated every epoch, so the epoch must be set from outside
+
+        self.current_step: int = 0
+        self.max_train_steps: int = 0
+        self.seed: int = 0
+
+        # augmentation
+        self.aug_helper = AugHelper()
+
+        self.image_transforms = IMAGE_TRANSFORMS
+
+        self.image_data: Dict[str, ImageInfo] = {}
+        self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
+
+        self.replacements = {}
+
+        # caching
+        self.caching_mode = None  # None, 'latents', 'text'
+
+    def set_seed(self, seed):
+        self.seed = seed
+
+    def set_caching_mode(self, mode):
+        self.caching_mode = mode
+
+    def set_current_epoch(self, epoch):
+        if self.current_epoch != epoch:  # shuffle the buckets when the epoch changes
+            self.shuffle_buckets()
+        self.current_epoch = epoch
+
+    def set_current_step(self, step):
+        self.current_step = step
+
+    def set_max_train_steps(self, max_train_steps):
+        self.max_train_steps = max_train_steps
+
+    def set_tag_frequency(self, dir_name, captions):
+        frequency_for_dir = self.tag_frequency.get(dir_name, {})
+        self.tag_frequency[dir_name] = frequency_for_dir
+        for caption in captions:
+            for tag in caption.split(","):
+                tag = tag.strip()
+                if tag:
+                    tag = tag.lower()
+                    frequency = frequency_for_dir.get(tag, 0)
+                    frequency_for_dir[tag] = frequency + 1
+
+    def disable_token_padding(self):
+        self.token_padding_disabled = True
+
+    def enable_XTI(self, layers=None, token_strings=None):
+        self.XTI_layers = layers
+        self.token_strings = token_strings
+
+    def add_replacement(self, str_from, str_to):
+        self.replacements[str_from] = str_to
+
+    def process_caption(self, subset: BaseSubset, caption):
+        # add the prefix/suffix to the caption
+        if subset.caption_prefix:
+            caption = subset.caption_prefix + " " + caption
+        if subset.caption_suffix:
+            caption = caption + " " + subset.caption_suffix
+
+        # decide caption dropout here, since tag dropout also happens inside this method
+        is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
+        is_drop_out = (
+            is_drop_out
+            or subset.caption_dropout_every_n_epochs > 0
+            and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
+        )
+
+        if is_drop_out:
+            caption = ""
+        else:
+            if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
+                fixed_tokens = []
+                flex_tokens = []
+                if (
+                    hasattr(subset, "keep_tokens_separator")
+                    and subset.keep_tokens_separator
+                    and subset.keep_tokens_separator in caption
+                ):
+                    fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
+                    fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
+                    flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
+                else:
+                    tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
+                    flex_tokens = tokens[:]
+                    if subset.keep_tokens > 0:
+                        fixed_tokens = flex_tokens[: subset.keep_tokens]
+                        flex_tokens = tokens[subset.keep_tokens :]
+
+                if subset.token_warmup_step < 1:  # overwritten on the first call: a fractional value becomes an absolute step count
+                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
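+                    # Illustrative example: with token_warmup_min=3, token_warmup_step=100 and
+                    # 20 flex tokens, step 50 keeps floor(50 * (20 - 3) / 100) + 3 = 11 tokens.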
+                    tokens_len = (
+                        math.floor(
+                            (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
+                        )
+                        + subset.token_warmup_min
+                    )
+                    flex_tokens = flex_tokens[:tokens_len]
+
+                def dropout_tags(tokens):
+                    if subset.caption_tag_dropout_rate <= 0:
+                        return tokens
+                    l = []
+                    for token in tokens:
+                        if random.random() >= subset.caption_tag_dropout_rate:
+                            l.append(token)
+                    return l
+
+                if subset.shuffle_caption:
+                    random.shuffle(flex_tokens)
+
+                flex_tokens = dropout_tags(flex_tokens)
+
+                caption = ", ".join(fixed_tokens + flex_tokens)
+
+            # textual inversion support (apply token replacements)
+            for str_from, str_to in self.replacements.items():
+                if str_from == "":
+                    # replace all
+                    if type(str_to) == list:
+                        caption = random.choice(str_to)
+                    else:
+                        caption = str_to
+                else:
+                    caption = caption.replace(str_from, str_to)
+
+        return caption
+
+    def get_input_ids(self, caption, tokenizer=None):
+        if tokenizer is None:
+            tokenizer = self.tokenizers[0]
+
+        input_ids = tokenizer(
+            caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
+        ).input_ids
+
+        if self.tokenizer_max_length > tokenizer.model_max_length:
+            input_ids = input_ids.squeeze(0)
+            iids_list = []
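+            # Illustrative example: max_token_length=225 gives tokenizer_max_length=227; with
+            # model_max_length=77 the loops below take i in (1, 76, 151) and build chunks of
+            # <BOS> + 75 content ids + <EOS>/<PAD>, so input_ids ends up with shape (3, 77).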
+            if tokenizer.pad_token_id == tokenizer.eos_token_id:
+                # v1
+                # 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
+                # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
+                for i in range(
+                    1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2
+                ):  # (1, 152, 75)
+                    ids_chunk = (
+                        input_ids[0].unsqueeze(0),
+                        input_ids[i : i + tokenizer.model_max_length - 2],
+                        input_ids[-1].unsqueeze(0),
+                    )
+                    ids_chunk = torch.cat(ids_chunk)
+                    iids_list.append(ids_chunk)
+            else:
+                # v2 or SDXL
+                # 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
+                for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
+                    ids_chunk = (
+                        input_ids[0].unsqueeze(0),  # BOS
+                        input_ids[i : i + tokenizer.model_max_length - 2],
+                        input_ids[-1].unsqueeze(0),
+                    )  # PAD or EOS
+                    ids_chunk = torch.cat(ids_chunk)
+
+                    # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
+                    # if it ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it was already x <EOS>)
+                    if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
+                        ids_chunk[-1] = tokenizer.eos_token_id
+                    # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
+                    if ids_chunk[1] == tokenizer.pad_token_id:
+                        ids_chunk[1] = tokenizer.eos_token_id
+
+                    iids_list.append(ids_chunk)
+
+            input_ids = torch.stack(iids_list)  # 3,77
+        return input_ids
+
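+    # Illustrative sketch of the chunking in get_input_ids above, assuming tokenizer_max_length == 227
+    # and model_max_length == 77 (the "(1, 152, 75)" case noted in the loop); example values only:
+    #
+    #   ids = tokenizer(caption, padding="max_length", truncation=True,
+    #                   max_length=227, return_tensors="pt").input_ids.squeeze(0)   # shape (227,)
+    #   # range(1, 152, 75) yields i = 1, 76, 151: three windows of 75 content tokens each
+    #   chunks = [torch.cat([ids[:1], ids[i : i + 75], ids[-1:]]) for i in range(1, 152, 75)]
+    #   stacked = torch.stack(chunks)                                               # shape (3, 77)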
+    def register_image(self, info: ImageInfo, subset: BaseSubset):
+        self.image_data[info.image_key] = info
+        self.image_to_subset[info.image_key] = subset
+
+    def make_buckets(self):
+        """
+        must be called even when bucketing is disabled (a single bucket is created)
+        min_size and max_size are ignored when enable_bucket is False
+        """
+        print("loading image sizes.")
+        for info in tqdm(self.image_data.values()):
+            if info.image_size is None:
+                info.image_size = self.get_image_size(info.absolute_path)
+
+        if self.enable_bucket:
+            print("make buckets")
+        else:
+            print("prepare dataset")
+
+        # create buckets and assign each image to one of them
+        if self.enable_bucket:
+            if self.bucket_manager is None:  # already initialized when fine tuning metadata contains bucket definitions
+                self.bucket_manager = BucketManager(
+                    self.bucket_no_upscale,
+                    (self.width, self.height),
+                    self.min_bucket_reso,
+                    self.max_bucket_reso,
+                    self.bucket_reso_steps,
+                )
+                if not self.bucket_no_upscale:
+                    self.bucket_manager.make_buckets()
+                else:
+                    print(
+                        "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
+                    )
+
+            img_ar_errors = []
+            for image_info in self.image_data.values():
+                image_width, image_height = image_info.image_size
+                image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
+                    image_width, image_height
+                )
+
+                # print(image_info.image_key, image_info.bucket_reso)
+                img_ar_errors.append(abs(ar_error))
+
+            self.bucket_manager.sort()
+        else:
+            self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
+            self.bucket_manager.set_predefined_resos([(self.width, self.height)])  # a single fixed-size bucket
+            for image_info in self.image_data.values():
+                image_width, image_height = image_info.image_size
+                image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
+
+        for image_info in self.image_data.values():
+            for _ in range(image_info.num_repeats):
+                self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
+
+        # print and store bucket info
+        if self.enable_bucket:
+            self.bucket_info = {"buckets": {}}
+            print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
+            for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
+                count = len(bucket)
+                if count > 0:
+                    self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
+                    print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
+
+            img_ar_errors = np.array(img_ar_errors)
+            mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+            self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
+            print(f"mean ar error (without repeats): {mean_img_ar_error}")
+
+        # build the index used to look up data; it is also what gets shuffled for the dataset
+        self.buckets_indices: List[BucketBatchIndex] = []
+        for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
+            batch_count = int(math.ceil(len(bucket) / self.batch_size))
+            for batch_index in range(batch_count):
+                self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
+
+            # ↓ the following was reverted because it made the number of batches per bucket grow too much and caused confusion
+            #   since training steps are sampled randomly, having the same image twice in one batch should not hurt much
+            #
+            # # as buckets get finer, more of them contain only one kind of image, which means a whole batch
+            # # can be filled with the same image; that is clearly undesirable
+            # # therefore limit the batch size to the number of distinct images in the bucket
+            # # the same image can still land in one batch, so fewer repeats should still give better shuffle quality (probably?)
+            # # TODO a mechanism to reuse regularization images across epochs
+            # num_of_image_types = len(set(bucket))
+            # bucket_batch_size = min(self.batch_size, num_of_image_types)
+            # batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
+            # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
+            # for batch_index in range(batch_count):
+            #   self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
+            # ↑ end of the reverted block
+
+        self.shuffle_buckets()
+        self._length = len(self.buckets_indices)
+
+    def shuffle_buckets(self):
+        # set random seed for this epoch
+        random.seed(self.seed + self.current_epoch)
+
+        random.shuffle(self.buckets_indices)
+        self.bucket_manager.shuffle()
+
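+    # Illustrative layout of buckets_indices built in make_buckets (example numbers): with two
+    # buckets holding 5 and 3 entries (repeats included) and batch_size=2, the list is
+    # [(0, 2, 0), (0, 2, 1), (0, 2, 2), (1, 2, 0), (1, 2, 1)] as BucketBatchIndex tuples of
+    # (bucket_index, bucket_batch_size, batch_index); __getitem__ then slices
+    # bucket[batch_index * 2 : batch_index * 2 + 2], so the last batch of a bucket can be short.
+    # shuffle_buckets reorders both this list and the entries inside each bucket every epoch,
+    # seeded with seed + current_epoch.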
+    def verify_bucket_reso_steps(self, min_steps: int):
+        assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, (
+            f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
+            + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
+        )
+
+    def is_latent_cacheable(self):
+        return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
+
+    def is_text_encoder_output_cacheable(self):
+        return all(
+            [
+                not (
+                    subset.caption_dropout_rate > 0
+                    or subset.shuffle_caption
+                    or subset.token_warmup_step > 0
+                    or subset.caption_tag_dropout_rate > 0
+                )
+                for subset in self.subsets
+            ]
+        )
+
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
+        # multi-GPU is not supported; use tools/cache_latents.py for that
+        print("caching latents.")
+
+        image_infos = list(self.image_data.values())
+
+        # sort by resolution
+        image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
+
+        # split by resolution
+        batches = []
+        batch = []
+        print("checking cache validity...")
+        for info in tqdm(image_infos):
+            subset = self.image_to_subset[info.image_key]
+
+            if info.latents_npz is not None:  # fine tuning dataset
+                continue
+
+            # check disk cache exists and size of latents
+            if cache_to_disk:
+                info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
+                if not is_main_process:  # store to info only
+                    continue
+
+                cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
+
+                if cache_available:  # do not add to batch
+                    continue
+
+            # if last member of batch has different resolution, flush the batch
+            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
+                batches.append(batch)
+                batch = []
+
+            batch.append(info)
+
+            # if number of data in batch is enough, flush the batch
+            if len(batch) >= vae_batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
+            return
+
+        # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
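+        # note: `subset` below is simply whatever the loop above last assigned, so flip_aug and
+        # random_crop are effectively assumed to be identical across all subsets of this dataset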
+        print("caching latents...")
+        for batch in tqdm(batches, smoothing=1, total=len(batches)):
+            cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
+
+    # if weight_dtype is given, the Text Encoders themselves and their outputs are cast to weight_dtype
+    # this is only valid for SDXL, but it has to be a dataset method, so it lives here rather than in sdxl_train_util.py
+    # supporting SD1/2 would require carrying a v2 flag, so that is postponed
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        assert len(tokenizers) == 2, "only support SDXL"
+
+        # caching to disk is supported in the same way as for latents
+        # multi-GPU is not supported; use tools/cache_latents.py for that
+        print("caching text encoder outputs.")
+        image_infos = list(self.image_data.values())
+
+        print("checking cache existence...")
+        image_infos_to_cache = []
+        for info in tqdm(image_infos):
+            # subset = self.image_to_subset[info.image_key]
+            if cache_to_disk:
+                te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+                info.text_encoder_outputs_npz = te_out_npz
+
+                if not is_main_process:  # store to info only
+                    continue
+
+                if os.path.exists(te_out_npz):
+                    continue
+
+            image_infos_to_cache.append(info)
+
+        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
+            return
+
+        # prepare tokenizers and text encoders
+        for text_encoder in text_encoders:
+            text_encoder.to(device)
+            if weight_dtype is not None:
+                text_encoder.to(dtype=weight_dtype)
+
+        # create batch
+        batch = []
+        batches = []
+        for info in image_infos_to_cache:
+            input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
+            input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
+            batch.append((info, input_ids1, input_ids2))
+
+            if len(batch) >= self.batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        # iterate batches: call text encoder and cache outputs for memory or disk
+        print("caching text encoder outputs...")
+        for batch in tqdm(batches):
+            infos, input_ids1, input_ids2 = zip(*batch)
+            input_ids1 = torch.stack(input_ids1, dim=0)
+            input_ids2 = torch.stack(input_ids2, dim=0)
+            cache_batch_text_encoder_outputs(
+                infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
+            )
+
+    def get_image_size(self, image_path):
+        image = Image.open(image_path)
+        return image.size
+
+    def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
+        img = load_image(image_path)
+
+        face_cx = face_cy = face_w = face_h = 0
+        if subset.face_crop_aug_range is not None:
+            tokens = os.path.splitext(os.path.basename(image_path))[0].split("_")
+            if len(tokens) >= 5:
+                face_cx = int(tokens[-4])
+                face_cy = int(tokens[-3])
+                face_w = int(tokens[-2])
+                face_h = int(tokens[-1])
+
+        return img, face_cx, face_cy, face_w, face_h
+
+    # crop a reasonable region around the face
+    def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
+        height, width = image.shape[0:2]
+        if height == self.height and width == self.width:
+            return image
+
+        # the image is larger than the target size, so resize it
+        face_size = max(face_w, face_h)
+        size = min(self.height, self.width)  # shorter side
+        min_scale = max(self.height / height, self.width / width)  # scale at which the image just covers the model input size (smallest allowed scale)
+        min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1])))  # corresponds to the specified minimum face size
+        max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0])))  # corresponds to the specified maximum face size
+        if min_scale >= max_scale:  # the range is specified with min == max
+            scale = min_scale
+        else:
+            scale = random.uniform(min_scale, max_scale)
+
+        nh = int(height * scale + 0.5)
+        nw = int(width * scale + 0.5)
+        assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
+        image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
+        face_cx = int(face_cx * scale + 0.5)
+        face_cy = int(face_cy * scale + 0.5)
+        height, width = nh, nw
+
+        # crop to e.g. 448*640 with the face at the center
+        for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
+            p1 = face_p - target_size // 2  # crop position that centers the face
+
+            if subset.random_crop:
+                # shift the crop while keeping the face near the center with high probability, so some background is included
+                range = max(length - face_p, face_p)  # the longer distance from the face center to an image edge
+                p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # a roughly triangular random offset in -range ~ +range
+            else:
+                # only when a range is specified, add a small random shift (fairly ad hoc)
+                if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
+                    if face_size > size // 10 and face_size >= 40:
+                        p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
+
+            p1 = max(0, min(p1, length - target_size))
+
+            if axis == 0:
+                image = image[p1 : p1 + target_size, :]
+            else:
+                image = image[:, p1 : p1 + target_size]
+
+        return image
+
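+    # Worked example for the scale range in crop_target above (illustrative numbers only):
+    # for a 1024x768 image, self.width == self.height == 512, face_size == 400 and
+    # face_crop_aug_range == (1.0, 2.0):
+    #   covering scale = max(512 / 768, 512 / 1024)              = 0.667
+    #   min_scale      = min(1.0, max(0.667, 512 / (400 * 2.0))) = 0.667
+    #   max_scale      = min(1.0, max(0.667, 512 / (400 * 1.0))) = 1.0
+    # so scale is drawn uniformly from [0.667, 1.0]; at 0.667 the image becomes roughly 683x512,
+    # which still covers the 512x512 crop window.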
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, index):
+        bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
+        bucket_batch_size = self.buckets_indices[index].bucket_batch_size
+        image_index = self.buckets_indices[index].batch_index * bucket_batch_size
+
+        if self.caching_mode is not None:  # return batch for latents/text encoder outputs caching
+            return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
+
+        loss_weights = []
+        captions = []
+        input_ids_list = []
+        input_ids2_list = []
+        latents_list = []
+        images = []
+        original_sizes_hw = []
+        crop_top_lefts = []
+        target_sizes_hw = []
+        flippeds = []  # awkward variable name
+        text_encoder_outputs1_list = []
+        text_encoder_outputs2_list = []
+        text_encoder_pool2_list = []
+
+        for image_key in bucket[image_index : image_index + bucket_batch_size]:
+            image_info = self.image_data[image_key]
+            subset = self.image_to_subset[image_key]
+            loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
+
+            flipped = subset.flip_aug and random.random() < 0.5  # not flipped or flipped with 50% chance
+
+            # process the image / latents
+            if image_info.latents is not None:  # when cache_latents=True
+                original_size = image_info.latents_original_size
+                crop_ltrb = image_info.latents_crop_ltrb  # calc values later if flipped
+                if not flipped:
+                    latents = image_info.latents
+                else:
+                    latents = image_info.latents_flipped
+
+                image = None
+            elif image_info.latents_npz is not None:  # FineTuningDataset, or cache_latents_to_disk=True
+                latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
+                if flipped:
+                    latents = flipped_latents
+                    del flipped_latents
+                latents = torch.FloatTensor(latents)
+
+                image = None
+            else:
+                # load the image and crop it if necessary
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
+                im_h, im_w = img.shape[0:2]
+
+                if self.enable_bucket:
+                    img, original_size, crop_ltrb = trim_and_resize_if_required(
+                        subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
+                    )
+                else:
+                    if face_cx > 0:  # face position info is available
+                        img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
+                    elif im_h > self.height or im_w > self.width:
+                        assert (
+                            subset.random_crop
+                        ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
+                        if im_h > self.height:
+                            p = random.randint(0, im_h - self.height)
+                            img = img[p : p + self.height]
+                        if im_w > self.width:
+                            p = random.randint(0, im_w - self.width)
+                            img = img[:, p : p + self.width]
+
+                    im_h, im_w = img.shape[0:2]
+                    assert (
+                        im_h == self.height and im_w == self.width
+                    ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
+
+                    original_size = [im_w, im_h]
+                    crop_ltrb = (0, 0, 0, 0)
+
+                # augmentation
+                aug = self.aug_helper.get_augmentor(subset.color_aug)
+                if aug is not None:
+                    img = aug(image=img)["image"]
+
+                if flipped:
+                    img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem
+
+                latents = None
+                image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
+
+            images.append(image)
+            latents_list.append(latents)
+
+            target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
+
+            if not flipped:
+                crop_left_top = (crop_ltrb[0], crop_ltrb[1])
+            else:
+                # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
+                crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
+
+            original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
+            crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
+            target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
+            flippeds.append(flipped)
+
+            # process the caption and text encoder outputs
+            caption = image_info.caption  # default
+            if image_info.text_encoder_outputs1 is not None:
+                text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
+                text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
+                text_encoder_pool2_list.append(image_info.text_encoder_pool2)
+                captions.append(caption)
+            elif image_info.text_encoder_outputs_npz is not None:
+                text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
+                    image_info.text_encoder_outputs_npz
+                )
+                text_encoder_outputs1_list.append(text_encoder_outputs1)
+                text_encoder_outputs2_list.append(text_encoder_outputs2)
+                text_encoder_pool2_list.append(text_encoder_pool2)
+                captions.append(caption)
+            else:
+                caption = self.process_caption(subset, image_info.caption)
+                if self.XTI_layers:
+                    caption_layer = []
+                    for layer in self.XTI_layers:
+                        token_strings_from = " ".join(self.token_strings)
+                        token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
+                        caption_ = caption.replace(token_strings_from, token_strings_to)
+                        caption_layer.append(caption_)
+                    captions.append(caption_layer)
+                else:
+                    captions.append(caption)
+
+                if not self.token_padding_disabled:  # this option might be omitted in future
+                    if self.XTI_layers:
+                        token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
+                    else:
+                        token_caption = self.get_input_ids(caption, self.tokenizers[0])
+                    input_ids_list.append(token_caption)
+
+                    if len(self.tokenizers) > 1:
+                        if self.XTI_layers:
+                            token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
+                        else:
+                            token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
+                        input_ids2_list.append(token_caption2)
+
+        example = {}
+        example["loss_weights"] = torch.FloatTensor(loss_weights)
+
+        if len(text_encoder_outputs1_list) == 0:
+            if self.token_padding_disabled:
+                # padding=True means pad in the batch
+                example["input_ids"] = self.tokenizers[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
+                if len(self.tokenizers) > 1:
+                    example["input_ids2"] = self.tokenizers[1](
+                        captions, padding=True, truncation=True, return_tensors="pt"
+                    ).input_ids
+                else:
+                    example["input_ids2"] = None
+            else:
+                example["input_ids"] = torch.stack(input_ids_list)
+                example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
+            example["text_encoder_outputs1_list"] = None
+            example["text_encoder_outputs2_list"] = None
+            example["text_encoder_pool2_list"] = None
+        else:
+            example["input_ids"] = None
+            example["input_ids2"] = None
+            # # for assertion
+            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
+            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
+            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
+            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
+            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
+
+        if images[0] is not None:
+            images = torch.stack(images)
+            images = images.to(memory_format=torch.contiguous_format).float()
+        else:
+            images = None
+        example["images"] = images
+
+        example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
+        example["captions"] = captions
+
+        example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
+        example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
+        example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
+        example["flippeds"] = flippeds
+
+        if self.debug_dataset:
+            example["image_keys"] = bucket[image_index : image_index + self.batch_size]
+        return example
+
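+    # Shape summary for the batch dict returned by __getitem__ above (B = bucket batch size;
+    # concrete sizes are examples): "images" is (B, 3, H, W) or None when latents are cached,
+    # "latents" is the stacked VAE latents (typically (B, 4, H / 8, W / 8)) or None,
+    # "loss_weights" is (B,), "input_ids"/"input_ids2" hold stacked token ids or None when text
+    # encoder outputs are cached, and "original_sizes_hw", "crop_top_lefts", "target_sizes_hw"
+    # are (B, 2) LongTensors used as SDXL size conditioning.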
+    def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
+        captions = []
+        images = []
+        input_ids1_list = []
+        input_ids2_list = []
+        absolute_paths = []
+        resized_sizes = []
+        bucket_reso = None
+        flip_aug = None
+        random_crop = None
+
+        for image_key in bucket[image_index : image_index + bucket_batch_size]:
+            image_info = self.image_data[image_key]
+            subset = self.image_to_subset[image_key]
+
+            if flip_aug is None:
+                flip_aug = subset.flip_aug
+                random_crop = subset.random_crop
+                bucket_reso = image_info.bucket_reso
+            else:
+                assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
+                assert random_crop == subset.random_crop, "random_crop must be same in a batch"
+                assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
+
+            caption = image_info.caption  # TODO cache some patterns of dropping, shuffling, etc.
+
+            if self.caching_mode == "latents":
+                image = load_image(image_info.absolute_path)
+            else:
+                image = None
+
+            if self.caching_mode == "text":
+                input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
+                input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
+            else:
+                input_ids1 = None
+                input_ids2 = None
+
+            captions.append(caption)
+            images.append(image)
+            input_ids1_list.append(input_ids1)
+            input_ids2_list.append(input_ids2)
+            absolute_paths.append(image_info.absolute_path)
+            resized_sizes.append(image_info.resized_size)
+
+        example = {}
+
+        if images[0] is None:
+            images = None
+        example["images"] = images
+
+        example["captions"] = captions
+        example["input_ids1_list"] = input_ids1_list
+        example["input_ids2_list"] = input_ids2_list
+        example["absolute_paths"] = absolute_paths
+        example["resized_sizes"] = resized_sizes
+        example["flip_aug"] = flip_aug
+        example["random_crop"] = random_crop
+        example["bucket_reso"] = bucket_reso
+        return example
+
+
+class DreamBoothDataset(BaseDataset):
+    def __init__(
+        self,
+        subsets: Sequence[DreamBoothSubset],
+        batch_size: int,
+        tokenizer,
+        max_token_length,
+        resolution,
+        enable_bucket: bool,
+        min_bucket_reso: int,
+        max_bucket_reso: int,
+        bucket_reso_steps: int,
+        bucket_no_upscale: bool,
+        prior_loss_weight: float,
+        debug_dataset,
+    ) -> None:
+        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+
+        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
+
+        self.batch_size = batch_size
+        self.size = min(self.width, self.height)  # shorter side
+        self.prior_loss_weight = prior_loss_weight
+        self.latents_cache = None
+
+        self.enable_bucket = enable_bucket
+        if self.enable_bucket:
+            assert (
+                min(resolution) >= min_bucket_reso
+            ), f"min_bucket_reso must be equal to or less than the resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
+            assert (
+                max(resolution) <= max_bucket_reso
+            ), f"max_bucket_reso must be equal to or greater than the resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmax_bucket_resoを大きくしてください"
+            self.min_bucket_reso = min_bucket_reso
+            self.max_bucket_reso = max_bucket_reso
+            self.bucket_reso_steps = bucket_reso_steps
+            self.bucket_no_upscale = bucket_no_upscale
+        else:
+            self.min_bucket_reso = None
+            self.max_bucket_reso = None
+            self.bucket_reso_steps = None  # this value is not used
+            self.bucket_no_upscale = False
+
+        def read_caption(img_path, caption_extension):
+            # build candidate caption file names
+            base_name = os.path.splitext(img_path)[0]
+            base_name_face_det = base_name
+            tokens = base_name.split("_")
+            if len(tokens) >= 5:
+                base_name_face_det = "_".join(tokens[:-4])
+            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
+
+            caption = None
+            for cap_path in cap_paths:
+                if os.path.isfile(cap_path):
+                    with open(cap_path, "rt", encoding="utf-8") as f:
+                        try:
+                            lines = f.readlines()
+                        except UnicodeDecodeError as e:
+                            print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
+                            raise e
+                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
+                        caption = lines[0].strip()
+                    break
+            return caption
+
+        def load_dreambooth_dir(subset: DreamBoothSubset):
+            if not os.path.isdir(subset.image_dir):
+                print(f"not directory: {subset.image_dir}")
+                return [], []
+
+            img_paths = glob_images(subset.image_dir, "*")
+            print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
+
+            # read a caption for each image file and use it if it exists
+            captions = []
+            missing_captions = []
+            for img_path in img_paths:
+                cap_for_img = read_caption(img_path, subset.caption_extension)
+                if cap_for_img is None and subset.class_tokens is None:
+                    print(
+                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                    )
+                    captions.append("")
+                    missing_captions.append(img_path)
+                else:
+                    if cap_for_img is None:
+                        captions.append(subset.class_tokens)
+                        missing_captions.append(img_path)
+                    else:
+                        captions.append(cap_for_img)
+
+            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequencies
+
+            if missing_captions:
+                number_of_missing_captions = len(missing_captions)
+                number_of_missing_captions_to_show = 5
+                remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
+
+                print(
+                    f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
+                )
+                for i, missing_caption in enumerate(missing_captions):
+                    if i >= number_of_missing_captions_to_show:
+                        print(missing_caption + f"... and {remaining_missing_captions} more")
+                        break
+                    print(missing_caption)
+            return img_paths, captions
+
+        print("prepare images.")
+        num_train_images = 0
+        num_reg_images = 0
+        reg_infos: List[ImageInfo] = []
+        for subset in subsets:
+            if subset.num_repeats < 1:
+                print(
+                    f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
+                )
+                continue
+
+            if subset in self.subsets:
+                print(
+                    f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
+                )
+                continue
+
+            img_paths, captions = load_dreambooth_dir(subset)
+            if len(img_paths) < 1:
+                print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
+                continue
+
+            if subset.is_reg:
+                num_reg_images += subset.num_repeats * len(img_paths)
+            else:
+                num_train_images += subset.num_repeats * len(img_paths)
+
+            for img_path, caption in zip(img_paths, captions):
+                info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if subset.is_reg:
+                    reg_infos.append(info)
+                else:
+                    self.register_image(info, subset)
+
+            subset.img_count = len(img_paths)
+            self.subsets.append(subset)
+
+        print(f"{num_train_images} train images with repeating.")
+        self.num_train_images = num_train_images
+
+        print(f"{num_reg_images} reg images.")
+        if num_train_images < num_reg_images:
+            print("some reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
+
+        if num_reg_images == 0:
+            print("no regularization images / 正則化画像が見つかりませんでした")
+        else:
+            # compute num_repeats: the counts are small anyway, so a simple loop is fine
+            n = 0
+            first_loop = True
+            while n < num_train_images:
+                for info in reg_infos:
+                    if first_loop:
+                        self.register_image(info, subset)
+                        n += info.num_repeats
+                    else:
+                        info.num_repeats += 1  # rewrite registered info
+                        n += 1
+                    if n >= num_train_images:
+                        break
+                first_loop = False
+
+        self.num_reg_images = num_reg_images
+
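+# Typical usage sketch (illustrative; the actual training scripts are the authoritative callers):
+# a dataset is finalized with make_buckets() and then wrapped in a DataLoader with batch_size=1,
+# because __getitem__ already returns one whole bucket batch, e.g.
+#
+#   dataset.make_buckets()
+#   dataset.set_max_train_steps(max_train_steps)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
+#                                        collate_fn=lambda b: b[0])
+#   for epoch in range(1, num_epochs + 1):
+#       dataset.set_current_epoch(epoch)
+#       for step, batch in enumerate(loader):
+#           dataset.set_current_step(step)
+#           ...  # batch["images"], batch["input_ids"], batch["loss_weights"], ...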
+
+class FineTuningDataset(BaseDataset):
+    def __init__(
+        self,
+        subsets: Sequence[FineTuningSubset],
+        batch_size: int,
+        tokenizer,
+        max_token_length,
+        resolution,
+        enable_bucket: bool,
+        min_bucket_reso: int,
+        max_bucket_reso: int,
+        bucket_reso_steps: int,
+        bucket_no_upscale: bool,
+        debug_dataset,
+    ) -> None:
+        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+
+        self.batch_size = batch_size
+
+        self.num_train_images = 0
+        self.num_reg_images = 0
+
+        for subset in subsets:
+            if subset.num_repeats < 1:
+                print(
+                    f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
+                )
+                continue
+
+            if subset in self.subsets:
+                print(
+                    f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
+                )
+                continue
+
+            # load metadata
+            if os.path.exists(subset.metadata_file):
+                print(f"loading existing metadata: {subset.metadata_file}")
+                with open(subset.metadata_file, "rt", encoding="utf-8") as f:
+                    metadata = json.load(f)
+            else:
+                raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
+
+            if len(metadata) < 1:
+                print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
+                continue
+
+            tags_list = []
+            for image_key, img_md in metadata.items():
+                # build the path info
+                abs_path = None
+
+                # look for the image file first
+                if os.path.exists(image_key):
+                    abs_path = image_key
+                else:
+                    # fairly rough, but no better approach comes to mind
+                    paths = glob_images(subset.image_dir, image_key)
+                    if len(paths) > 0:
+                        abs_path = paths[0]
+
+                # if not found, look for an npz file
+                if abs_path is None:
+                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
+                        abs_path = os.path.splitext(image_key)[0] + ".npz"
+                    else:
+                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
+                        if os.path.exists(npz_path):
+                            abs_path = npz_path
+
+                assert abs_path is not None, f"no image / 画像がありません: {image_key}"
+
+                caption = img_md.get("caption")
+                tags = img_md.get("tags")
+                if caption is None:
+                    caption = tags
+                elif tags is not None and len(tags) > 0:
+                    caption = caption + ", " + tags
+                    tags_list.append(tags)
+
+                if caption is None:
+                    caption = ""
+
+                image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
+                image_info.image_size = img_md.get("train_resolution")
+
+                if not subset.color_aug and not subset.random_crop:
+                    # if npz exists, use them
+                    image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
+
+                self.register_image(image_info, subset)
+
+            self.num_train_images += len(metadata) * subset.num_repeats
+
+            # TODO do not record tag freq when no tag
+            self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
+            subset.img_count = len(metadata)
+            self.subsets.append(subset)
+
+        # check existence of all npz files
+        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
+        if use_npz_latents:
+            flip_aug_in_subset = False
+            npz_any = False
+            npz_all = True
+
+            for image_info in self.image_data.values():
+                subset = self.image_to_subset[image_info.image_key]
+
+                has_npz = image_info.latents_npz is not None
+                npz_any = npz_any or has_npz
+
+                if subset.flip_aug:
+                    has_npz = has_npz and image_info.latents_npz_flipped is not None
+                    flip_aug_in_subset = True
+                npz_all = npz_all and has_npz
+
+                if npz_any and not npz_all:
+                    break
+
+            if not npz_any:
+                use_npz_latents = False
+                print(f"npz files do not exist. ignoring npz files / npzファイルが見つからないためnpzファイルを無視します")
+            elif not npz_all:
+                use_npz_latents = False
+                print(f"some npz files do not exist. ignoring npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
+                if flip_aug_in_subset:
+                    print("maybe the flipped npz files are missing / 反転されたnpzファイルがないのかもしれません")
+        # else:
+        #   print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
+
+        # check min/max bucket size
+        sizes = set()
+        resos = set()
+        for image_info in self.image_data.values():
+            if image_info.image_size is None:
+                sizes = None  # not calculated
+                break
+            sizes.add(image_info.image_size[0])
+            sizes.add(image_info.image_size[1])
+            resos.add(tuple(image_info.image_size))
+
+        if sizes is None:
+            if use_npz_latents:
+                use_npz_latents = False
+                print(f"npz files exist, but no bucket info in metadata. ignoring npz files / メタデータにbucket情報がないためnpzファイルを無視します")
+
+            assert (
+                resolution is not None
+            ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
+
+            self.enable_bucket = enable_bucket
+            if self.enable_bucket:
+                self.min_bucket_reso = min_bucket_reso
+                self.max_bucket_reso = max_bucket_reso
+                self.bucket_reso_steps = bucket_reso_steps
+                self.bucket_no_upscale = bucket_no_upscale
+        else:
+            if not enable_bucket:
+                print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
+            print("using bucket info in metadata / メタデータ内のbucket情報を使います")
+            self.enable_bucket = True
+
+            assert (
+                not bucket_no_upscale
+            ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
+
+            # initialize the bucket info here; make_buckets will not recreate it
+            self.bucket_manager = BucketManager(False, None, None, None, None)
+            self.bucket_manager.set_predefined_resos(resos)
+
+        # clean up the npz info
+        if not use_npz_latents:
+            for image_info in self.image_data.values():
+                image_info.latents_npz = image_info.latents_npz_flipped = None
+
+    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
+        base_name = os.path.splitext(image_key)[0]
+        npz_file_norm = base_name + ".npz"
+
+        if os.path.exists(npz_file_norm):
+            # image_key is full path
+            npz_file_flip = base_name + "_flip.npz"
+            if not os.path.exists(npz_file_flip):
+                npz_file_flip = None
+            return npz_file_norm, npz_file_flip
+
+        # if not full path, check image_dir. if image_dir is None, return None
+        if subset.image_dir is None:
+            return None, None
+
+        # image_key is relative path
+        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
+        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
+
+        if not os.path.exists(npz_file_norm):
+            npz_file_norm = None
+            npz_file_flip = None
+        elif not os.path.exists(npz_file_flip):
+            npz_file_flip = None
+
+        return npz_file_norm, npz_file_flip
+
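+    # Resolution order of image_key_to_npz_file above, with hypothetical paths: for a full-path
+    # key "/data/img/001.png" the cache is looked up as "/data/img/001.npz" (plus
+    # "/data/img/001_flip.npz" when flip_aug is used); for a bare key like "001" it is resolved
+    # under subset.image_dir as "<image_dir>/001.npz". A missing flipped file yields None for
+    # that slot.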
+
+class ControlNetDataset(BaseDataset):
+    def __init__(
+        self,
+        subsets: Sequence[ControlNetSubset],
+        batch_size: int,
+        tokenizer,
+        max_token_length,
+        resolution,
+        enable_bucket: bool,
+        min_bucket_reso: int,
+        max_bucket_reso: int,
+        bucket_reso_steps: int,
+        bucket_no_upscale: bool,
+        debug_dataset,
+    ) -> None:
+        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+
+        db_subsets = []
+        for subset in subsets:
+            db_subset = DreamBoothSubset(
+                subset.image_dir,
+                False,
+                None,
+                subset.caption_extension,
+                subset.num_repeats,
+                subset.shuffle_caption,
+                subset.caption_separator,
+                subset.keep_tokens,
+                subset.keep_tokens_separator,
+                subset.color_aug,
+                subset.flip_aug,
+                subset.face_crop_aug_range,
+                subset.random_crop,
+                subset.caption_dropout_rate,
+                subset.caption_dropout_every_n_epochs,
+                subset.caption_tag_dropout_rate,
+                subset.caption_prefix,
+                subset.caption_suffix,
+                subset.token_warmup_min,
+                subset.token_warmup_step,
+            )
+            db_subsets.append(db_subset)
+
+        self.dreambooth_dataset_delegate = DreamBoothDataset(
+            db_subsets,
+            batch_size,
+            tokenizer,
+            max_token_length,
+            resolution,
+            enable_bucket,
+            min_bucket_reso,
+            max_bucket_reso,
+            bucket_reso_steps,
+            bucket_no_upscale,
+            1.0,
+            debug_dataset,
+        )
+
+        # store values referenced from config_util etc. (a bit awkward; should be cleaned up eventually)
+        self.image_data = self.dreambooth_dataset_delegate.image_data
+        self.batch_size = batch_size
+        self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
+        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+
+        # assert all conditioning data exists
+        missing_imgs = []
+        cond_imgs_with_img = set()
+        for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
+            db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
+            subset = None
+            for s in subsets:
+                if s.image_dir == db_subset.image_dir:
+                    subset = s
+                    break
+            assert subset is not None, "internal error: subset not found"
+
+            if not os.path.isdir(subset.conditioning_data_dir):
+                print(f"not directory: {subset.conditioning_data_dir}")
+                continue
+
+            img_basename = os.path.basename(info.absolute_path)
+            ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
+            if not os.path.exists(ctrl_img_path):
+                missing_imgs.append(img_basename)
+
+            info.cond_img_path = ctrl_img_path
+            cond_imgs_with_img.add(ctrl_img_path)
+
+        extra_imgs = []
+        for subset in subsets:
+            conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
+            extra_imgs.extend(
+                [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
+            )
+
+        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
+        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
+
+        self.conditioning_image_transforms = IMAGE_TRANSFORMS
+
+    def make_buckets(self):
+        self.dreambooth_dataset_delegate.make_buckets()
+        self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
+        self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
+
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
+        return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
+
+    def __len__(self):
+        return self.dreambooth_dataset_delegate.__len__()
+
+    def __getitem__(self, index):
+        example = self.dreambooth_dataset_delegate[index]
+
+        bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[
+            self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index
+        ]
+        bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size
+        image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size
+
+        conditioning_images = []
+
+        for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]):
+            image_info = self.dreambooth_dataset_delegate.image_data[image_key]
+
+            target_size_hw = example["target_sizes_hw"][i]
+            original_size_hw = example["original_sizes_hw"][i]
+            crop_top_left = example["crop_top_lefts"][i]
+            flipped = example["flippeds"][i]
+            cond_img = load_image(image_info.cond_img_path)
+
+            if self.dreambooth_dataset_delegate.enable_bucket:
+                assert (
+                    cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
+                ), f"size of conditioning image does not match / 画像サイズが合いません: {image_info.absolute_path}"
+                cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because INTER_AREA is wanted
+
+                # TODO support random crop
+                # currently only center crop is supported, not random crop
+                h, w = target_size_hw
+                ct = (cond_img.shape[0] - h) // 2
+                cl = (cond_img.shape[1] - w) // 2
+                cond_img = cond_img[ct : ct + h, cl : cl + w]
+            else:
+                # assert (
+                #     cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
+                # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
+                # resize to target
+                if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
+                    cond_img = cv2.resize(
+                        cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
+                    )
+
+            if flipped:
+                cond_img = cond_img[:, ::-1, :].copy()  # copy to avoid negative stride
+
+            cond_img = self.conditioning_image_transforms(cond_img)
+            conditioning_images.append(cond_img)
+
+        example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
+
+        return example
+
+
+# behave as Dataset mock
+class DatasetGroup(torch.utils.data.ConcatDataset):
+    def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
+        self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
+
+        super().__init__(datasets)
+
+        self.image_data = {}
+        self.num_train_images = 0
+        self.num_reg_images = 0
+
+        # simply concat together
+        # TODO: handling image_data key duplication among dataset
+        #   In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
+        for dataset in datasets:
+            self.image_data.update(dataset.image_data)
+            self.num_train_images += dataset.num_train_images
+            self.num_reg_images += dataset.num_reg_images
+
+    def add_replacement(self, str_from, str_to):
+        for dataset in self.datasets:
+            dataset.add_replacement(str_from, str_to)
+
+    # def make_buckets(self):
+    #   for dataset in self.datasets:
+    #     dataset.make_buckets()
+
+    def enable_XTI(self, *args, **kwargs):
+        for dataset in self.datasets:
+            dataset.enable_XTI(*args, **kwargs)
+
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
+        for i, dataset in enumerate(self.datasets):
+            print(f"[Dataset {i}]")
+            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
+
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        for i, dataset in enumerate(self.datasets):
+            print(f"[Dataset {i}]")
+            dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
+
+    def set_caching_mode(self, caching_mode):
+        for dataset in self.datasets:
+            dataset.set_caching_mode(caching_mode)
+
+    def verify_bucket_reso_steps(self, min_steps: int):
+        for dataset in self.datasets:
+            dataset.verify_bucket_reso_steps(min_steps)
+
+    def is_latent_cacheable(self) -> bool:
+        return all([dataset.is_latent_cacheable() for dataset in self.datasets])
+
+    def is_text_encoder_output_cacheable(self) -> bool:
+        return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
+
+    def set_current_epoch(self, epoch):
+        for dataset in self.datasets:
+            dataset.set_current_epoch(epoch)
+
+    def set_current_step(self, step):
+        for dataset in self.datasets:
+            dataset.set_current_step(step)
+
+    def set_max_train_steps(self, max_train_steps):
+        for dataset in self.datasets:
+            dataset.set_max_train_steps(max_train_steps)
+
+    def disable_token_padding(self):
+        for dataset in self.datasets:
+            dataset.disable_token_padding()
+
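+    # Usage sketch (illustrative): DatasetGroup behaves like a plain ConcatDataset while fanning
+    # the bookkeeping calls out to every member dataset, e.g.
+    #
+    #   group = DatasetGroup([dreambooth_dataset, fine_tuning_dataset])
+    #   group.verify_bucket_reso_steps(64)  # 64 is an example value
+    #   if group.is_latent_cacheable():
+    #       group.cache_latents(vae, vae_batch_size=1, cache_to_disk=True)
+    #   assert len(group) == sum(len(d) for d in group.datasets)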
+
+def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
+    expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note that bucket_reso is W x H
+
+    if not os.path.exists(npz_path):
+        return False
+
+    npz = np.load(npz_path)
+    if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:  # old ver?
+        return False
+    if npz["latents"].shape[1:3] != expected_latents_size:
+        return False
+
+    if flip_aug:
+        if "latents_flipped" not in npz:
+            return False
+        if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+            return False
+
+    return True
+
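+# Example of the check above (illustrative numbers): for bucket_reso (W, H) == (640, 448) the
+# cached "latents" array must satisfy shape[1:3] == (56, 80), i.e. (H // 8, W // 8), and when
+# flip_aug is enabled a matching "latents_flipped" array must also be present; anything else
+# forces the latents to be re-encoded.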
+
+# returns latents, (original_size width, original_size height), crop_ltrb (left, top, right, bottom), and flipped latents (or None)
+def load_latents_from_disk(
+    npz_path,
+) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
+    npz = np.load(npz_path)
+    if "latents" not in npz:
+        raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
+
+    latents = npz["latents"]
+    original_size = npz["original_size"].tolist()
+    crop_ltrb = npz["crop_ltrb"].tolist()
+    flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
+    return latents, original_size, crop_ltrb, flipped_latents
+
+
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
+    kwargs = {}
+    if flipped_latents_tensor is not None:
+        kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
+    np.savez(
+        npz_path,
+        latents=latents_tensor.float().cpu().numpy(),
+        original_size=np.array(original_size),
+        crop_ltrb=np.array(crop_ltrb),
+        **kwargs,
+    )
+
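+# Round-trip sketch for the two helpers above (dummy data, illustrative only):
+#
+#   lat = torch.randn(4, 64, 64)
+#   save_latents_to_disk("sample.npz", lat, original_size=(512, 512), crop_ltrb=(0, 0, 0, 0))
+#   lat2, orig_size, crop_ltrb, flipped = load_latents_from_disk("sample.npz")
+#   # lat2 is a float32 numpy array of shape (4, 64, 64); flipped is None because no flipped
+#   # latents were saved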
+
+def debug_dataset(train_dataset, show_input_ids=False):
+    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
+    print("`S` for next step, `E` for next epoch, Escape to exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
+
+    epoch = 1
+    while True:
+        print(f"\nepoch: {epoch}")
+
+        steps = (epoch - 1) * len(train_dataset) + 1
+        indices = list(range(len(train_dataset)))
+        random.shuffle(indices)
+
+        k = 0
+        for i, idx in enumerate(indices):
+            train_dataset.set_current_epoch(epoch)
+            train_dataset.set_current_step(steps)
+            print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
+
+            example = train_dataset[idx]
+            if example["latents"] is not None:
+                print(f"sample has latents from npz file: {example['latents'].size()}")
+            for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
+                zip(
+                    example["image_keys"],
+                    example["captions"],
+                    example["loss_weights"],
+                    example["input_ids"],
+                    example["original_sizes_hw"],
+                    example["crop_top_lefts"],
+                    example["target_sizes_hw"],
+                    example["flippeds"],
+                )
+            ):
+                print(
+                    f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
+                )
+
+                if show_input_ids:
+                    print(f"input ids: {iid}")
+                    if "input_ids2" in example:
+                        print(f"input ids2: {example['input_ids2'][j]}")
+                if example["images"] is not None:
+                    im = example["images"][j]
+                    print(f"image size: {im.size()}")
+                    im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
+                    im = np.transpose(im, (1, 2, 0))  # c,H,W -> H,W,c
+                    im = im[:, :, ::-1]  # RGB -> BGR (OpenCV)
+
+                    if "conditioning_images" in example:
+                        cond_img = example["conditioning_images"][j]
+                        print(f"conditioning image size: {cond_img.size()}")
+                        cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8)
+                        cond_img = np.transpose(cond_img, (1, 2, 0))
+                        cond_img = cond_img[:, :, ::-1]
+                        if os.name == "nt":
+                            cv2.imshow("cond_img", cond_img)
+
+                    if os.name == "nt":  # only windows
+                        cv2.imshow("img", im)
+                        k = cv2.waitKey()
+                        cv2.destroyAllWindows()
+                    if k == 27 or k == ord("s") or k == ord("e"):
+                        break
+            steps += 1
+
+            if k == ord("e"):
+                break
+            if k == 27 or (example["images"] is None and i >= 8):
+                k = 27
+                break
+        if k == 27:
+            break
+
+        epoch += 1
+
+
+def glob_images(directory, base="*"):
+    img_paths = []
+    for ext in IMAGE_EXTENSIONS:
+        if base == "*":
+            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+        else:
+            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+    img_paths = list(set(img_paths))  # remove duplicates
+    img_paths.sort()
+    return img_paths
+
+
+def glob_images_pathlib(dir_path, recursive):
+    image_paths = []
+    if recursive:
+        for ext in IMAGE_EXTENSIONS:
+            image_paths += list(dir_path.rglob("*" + ext))
+    else:
+        for ext in IMAGE_EXTENSIONS:
+            image_paths += list(dir_path.glob("*" + ext))
+    image_paths = list(set(image_paths))  # remove duplicates
+    image_paths.sort()
+    return image_paths
+
+
+class MinimalDataset(BaseDataset):
+    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
+        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+
+        self.num_train_images = 0  # update in subclass
+        self.num_reg_images = 0  # update in subclass
+        self.datasets = [self]
+        self.batch_size = 1  # update in subclass
+
+        self.subsets = [self]
+        self.num_repeats = 1  # update in subclass if needed
+        self.img_count = 1  # update in subclass if needed
+        self.bucket_info = {}
+        self.is_reg = False
+        self.image_dir = "dummy"  # for metadata
+
+    def verify_bucket_reso_steps(self, min_steps: int):
+        pass
+
+    def is_latent_cacheable(self) -> bool:
+        return False
+
+    def __len__(self):
+        raise NotImplementedError
+
+    # override to avoid shuffling buckets
+    def set_current_epoch(self, epoch):
+        self.current_epoch = epoch
+
+    def __getitem__(self, idx):
+        r"""
+        The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
+
+        Returns: example like this:
+
+            for i in range(batch_size):
+                image_key = ...  # whatever hashable
+                image_keys.append(image_key)
+
+                image = ...  # PIL Image
+                img_tensor = self.image_transforms(image)
+                images.append(img_tensor)
+
+                caption = ...  # str
+                input_ids = self.get_input_ids(caption)
+                input_ids_list.append(input_ids)
+
+                captions.append(caption)
+
+            images = torch.stack(images, dim=0)
+            input_ids_list = torch.stack(input_ids_list, dim=0)
+            example = {
+                "images": images,
+                "input_ids": input_ids_list,
+                "captions": captions,   # for debug_dataset
+                "latents": None,
+                "image_keys": image_keys,   # for debug_dataset
+                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
+            }
+            return example
+        """
+        raise NotImplementedError
+
+
+def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
+    module = ".".join(args.dataset_class.split(".")[:-1])
+    dataset_class = args.dataset_class.split(".")[-1]
+    module = importlib.import_module(module)
+    dataset_class = getattr(module, dataset_class)
+    train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
+    return train_dataset_group
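+
+
+# Illustrative sketch (hypothetical module and class names) of how --dataset_class is resolved:
+# "my_package.my_datasets.MyDataset" imports my_package.my_datasets and instantiates
+# MyDataset(tokenizer, max_token_length, resolution, debug_dataset), so the class should
+# subclass MinimalDataset and accept exactly these constructor arguments.
+#
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(dataset_class="my_package.my_datasets.MyDataset",
+#                          max_token_length=None, resolution=(512, 512), debug_dataset=False)
+#   dataset = load_arbitrary_dataset(args, tokenizer)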
+
+
+def load_image(image_path):
+    image = Image.open(image_path)
+    if not image.mode == "RGB":
+        image = image.convert("RGB")
+    img = np.array(image, np.uint8)
+    return img
+
+
+# Trim and resize the image if required. Returns numpy.ndarray, (original width, original height), (crop left, crop top, crop right, crop bottom)
+def trim_and_resize_if_required(
+    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
+) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
+    image_height, image_width = image.shape[0:2]
+    original_size = (image_width, image_height)  # size before resize
+
+    if image_width != resized_size[0] or image_height != resized_size[1]:
+        # resize with cv2 because we want INTER_AREA interpolation
+        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
+
+    image_height, image_width = image.shape[0:2]
+
+    if image_width > reso[0]:
+        trim_size = image_width - reso[0]
+        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
+        # print("w", trim_size, p)
+        image = image[:, p : p + reso[0]]
+    if image_height > reso[1]:
+        trim_size = image_height - reso[1]
+        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
+        # print("h", trim_size, p)
+        image = image[p : p + reso[1]]
+
+    # I have no idea how to reflect the cropped value in crop left/top in the case of random crop
+
+    crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)
+
+    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
+    return image, original_size, crop_ltrb
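+
+
+# Worked example (added for clarity, values are illustrative): with reso=(640, 448) (WxH), an image
+# already resized to 704x448 and random_crop=False, trim_size = 704 - 640 = 64, so p = 32 and the
+# center 640 columns are kept. original_size is always the size before resizing, and crop_ltrb is
+# derived from BucketManager.get_crop_ltrb(reso, original_size) rather than from the actual crop.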
+
+
+def cache_batch_latents(
+    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
+) -> None:
+    r"""
+    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
+    optionally requires image_infos to have: image
+    if cache_to_disk is True, set info.latents_npz
+        flipped latents is also saved if flip_aug is True
+    if cache_to_disk is False, set info.latents
+        latents_flipped is also set if flip_aug is True
+    latents_original_size and latents_crop_ltrb are also set
+    """
+    images = []
+    for info in image_infos:
+        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        # TODO: the image metadata can be corrupted so that the bucket assigned from metadata does not match the actual image size; add a check for this
+        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+        image = IMAGE_TRANSFORMS(image)
+        images.append(image)
+
+        info.latents_original_size = original_size
+        info.latents_crop_ltrb = crop_ltrb
+
+    img_tensors = torch.stack(images, dim=0)
+    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
+
+    with torch.no_grad():
+        latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+
+    if flip_aug:
+        img_tensors = torch.flip(img_tensors, dims=[3])
+        with torch.no_grad():
+            flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+    else:
+        flipped_latents = [None] * len(latents)
+
+    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+        # check NaN
+        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
+            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+
+        if cache_to_disk:
+            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
+        else:
+            info.latents = latent
+            if flip_aug:
+                info.latents_flipped = flipped_latent
+
+    # FIXME this slows down caching a lot, specify this as an option
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def cache_batch_text_encoder_outputs(
+    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
+):
+    input_ids1 = input_ids1.to(text_encoders[0].device)
+    input_ids2 = input_ids2.to(text_encoders[1].device)
+
+    with torch.no_grad():
+        b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
+            max_token_length,
+            input_ids1,
+            input_ids2,
+            tokenizers[0],
+            tokenizers[1],
+            text_encoders[0],
+            text_encoders[1],
+            dtype,
+        )
+
+        # move to cpu here, otherwise the buffers will be overwritten later
+        b_hidden_state1 = b_hidden_state1.detach().to("cpu")  # b,n*75+2,768
+        b_hidden_state2 = b_hidden_state2.detach().to("cpu")  # b,n*75+2,1280
+        b_pool2 = b_pool2.detach().to("cpu")  # b,1280
+
+    for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
+        if cache_to_disk:
+            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2)
+        else:
+            info.text_encoder_outputs1 = hidden_state1
+            info.text_encoder_outputs2 = hidden_state2
+            info.text_encoder_pool2 = pool2
+
+
+def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
+    np.savez(
+        npz_path,
+        hidden_state1=hidden_state1.cpu().float().numpy(),
+        hidden_state2=hidden_state2.cpu().float().numpy(),
+        pool2=pool2.cpu().float().numpy(),
+    )
+
+
+def load_text_encoder_outputs_from_disk(npz_path):
+    with np.load(npz_path) as f:
+        hidden_state1 = torch.from_numpy(f["hidden_state1"])
+        hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
+        pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
+    return hidden_state1, hidden_state2, pool2
+
+
+# endregion
+
+# region module replacement
+"""
+Module replacement for speed-up
+"""
+
+# CrossAttention using FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+def model_hash(filename):
+    """Old model hash used by stable-diffusion-webui"""
+    try:
+        with open(filename, "rb") as file:
+            m = hashlib.sha256()
+
+            file.seek(0x100000)
+            m.update(file.read(0x10000))
+            return m.hexdigest()[0:8]
+    except FileNotFoundError:
+        return "NOFILE"
+    except IsADirectoryError:  # Linux?
+        return "IsADirectory"
+    except PermissionError:  # Windows raises PermissionError for directories
+        return "IsADirectory"
+
+
+def calculate_sha256(filename):
+    """New model hash used by stable-diffusion-webui"""
+    try:
+        hash_sha256 = hashlib.sha256()
+        blksize = 1024 * 1024
+
+        with open(filename, "rb") as f:
+            for chunk in iter(lambda: f.read(blksize), b""):
+                hash_sha256.update(chunk)
+
+        return hash_sha256.hexdigest()
+    except FileNotFoundError:
+        return "NOFILE"
+    except IsADirectoryError:  # Linux?
+        return "IsADirectory"
+    except PermissionError:  # Windows raises PermissionError for directories
+        return "IsADirectory"
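+
+
+# Usage sketch (hypothetical file name): model_hash() reads only 64KiB at offset 0x100000
+# (the legacy webui hash), while calculate_sha256() hashes the entire file.
+#
+#   short_hash = model_hash("model.safetensors")       # first 8 hex chars of the partial hash
+#   full_hash = calculate_sha256("model.safetensors")  # full 64-hex-char SHA-256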
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+    serialized = safetensors.torch.save(tensors, metadata)  # avoid shadowing the builtin `bytes`
+    b = BytesIO(serialized)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    return model_hash, legacy_hash
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
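+
+
+# Note (added for clarity): a .safetensors file starts with an 8-byte little-endian length n of the
+# JSON header, followed by the header and then the tensor data. addnet_hash_safetensors therefore
+# hashes only the bytes from offset n + 8 onward, so editing user metadata in the header does not
+# change this hash, while addnet_hash_legacy hashes a fixed 64KiB window at offset 0x100000.
+#
+#   with open("model.safetensors", "rb") as f:  # hypothetical path
+#       b = BytesIO(f.read())
+#   print(addnet_hash_safetensors(b), addnet_hash_legacy(b))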
+
+
+def get_git_revision_hash() -> str:
+    try:
+        return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip()
+    except Exception:
+        return "(unknown)"
+
+
+# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
+#     replace_attentions_for_hypernetwork()
+#     # unet is not used currently, but it is here for future use
+#     unet.enable_xformers_memory_efficient_attention()
+#     return
+#     if mem_eff_attn:
+#         unet.set_attn_processor(FlashAttnProcessor())
+#     elif xformers:
+#         unet.enable_xformers_memory_efficient_attention()
+
+
+# def replace_unet_cross_attn_to_xformers():
+#     print("CrossAttention.forward has been replaced to enable xformers.")
+#     try:
+#         import xformers.ops
+#     except ImportError:
+#         raise ImportError("No xformers / xformersがインストールされていないようです")
+
+#     def forward_xformers(self, x, context=None, mask=None):
+#         h = self.heads
+#         q_in = self.to_q(x)
+
+#         context = default(context, x)
+#         context = context.to(x.dtype)
+
+#         if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
+#             context_k, context_v = self.hypernetwork.forward(x, context)
+#             context_k = context_k.to(x.dtype)
+#             context_v = context_v.to(x.dtype)
+#         else:
+#             context_k = context
+#             context_v = context
+
+#         k_in = self.to_k(context_k)
+#         v_in = self.to_v(context_v)
+
+#         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+#         del q_in, k_in, v_in
+
+#         q = q.contiguous()
+#         k = k.contiguous()
+#         v = v.contiguous()
+#         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # 最適なのを選んでくれる
+
+#         out = rearrange(out, "b n h d -> b n (h d)", h=h)
+
+#         # diffusers 0.7.0~
+#         out = self.to_out[0](out)
+#         out = self.to_out[1](out)
+#         return out
+
+
+#     diffusers.models.attention.CrossAttention.forward = forward_xformers
+def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
+    if mem_eff_attn:
+        print("Enable memory efficient attention for U-Net")
+        unet.set_use_memory_efficient_attention(False, True)
+    elif xformers:
+        print("Enable xformers for U-Net")
+        try:
+            import xformers.ops
+        except ImportError:
+            raise ImportError("No xformers / xformersがインストールされていないようです")
+
+        unet.set_use_memory_efficient_attention(True, False)
+    elif sdpa:
+        print("Enable SDPA for U-Net")
+        unet.set_use_sdpa(True)
+
+
+"""
+def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
+    # vae is not used currently, but it is here for future use
+    if mem_eff_attn:
+        replace_vae_attn_to_memory_efficient()
+    elif xformers:
+        # use Diffusers' xformers for now; only the MidBlock has Attention
+        print("Use Diffusers xformers for VAE")
+        vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+        vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+
+
+def replace_vae_attn_to_memory_efficient():
+    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn(self, hidden_states):
+        print("forward_flash_attn")
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(out)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+"""
+
+
+# endregion
+
+
+# region arguments
+
+
+def load_metadata_from_safetensors(safetensors_file: str) -> dict:
+    r"""
+    This method locks the file. see https://github.com/huggingface/safetensors/issues/164
+    If the file isn't .safetensors or doesn't have metadata, return empty dict.
+    """
+    if os.path.splitext(safetensors_file)[1] != ".safetensors":
+        return {}
+
+    with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+    if metadata is None:
+        metadata = {}
+    return metadata
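+
+
+# Usage sketch (hypothetical path): read the training metadata embedded in a LoRA file.
+#
+#   metadata = load_metadata_from_safetensors("lora.safetensors")
+#   print(metadata.get("ss_network_dim"), metadata.get("ss_network_alpha"))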
+
+
+# this metadata is referenced by train_network and various other scripts, so it is defined here
+SS_METADATA_KEY_V2 = "ss_v2"
+SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
+SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
+SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
+SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
+SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
+
+SS_METADATA_MINIMUM_KEYS = [
+    SS_METADATA_KEY_V2,
+    SS_METADATA_KEY_BASE_MODEL_VERSION,
+    SS_METADATA_KEY_NETWORK_MODULE,
+    SS_METADATA_KEY_NETWORK_DIM,
+    SS_METADATA_KEY_NETWORK_ALPHA,
+    SS_METADATA_KEY_NETWORK_ARGS,
+]
+
+
+def build_minimum_network_metadata(
+    v2: Optional[bool],
+    base_model: Optional[str],
+    network_module: str,
+    network_dim: str,
+    network_alpha: str,
+    network_args: Optional[dict],
+):
+    # old LoRA doesn't have base_model
+    metadata = {
+        SS_METADATA_KEY_NETWORK_MODULE: network_module,
+        SS_METADATA_KEY_NETWORK_DIM: network_dim,
+        SS_METADATA_KEY_NETWORK_ALPHA: network_alpha,
+    }
+    if v2 is not None:
+        metadata[SS_METADATA_KEY_V2] = v2
+    if base_model is not None:
+        metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model
+    if network_args is not None:
+        metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args)
+    return metadata
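+
+
+# Example output (values are illustrative, not taken from a real model):
+#
+#   build_minimum_network_metadata(
+#       v2=False, base_model="sd_v1", network_module="networks.lora",
+#       network_dim="8", network_alpha="4", network_args={"conv_dim": 4},
+#   )
+#   # -> {"ss_network_module": "networks.lora", "ss_network_dim": "8", "ss_network_alpha": "4",
+#   #     "ss_v2": False, "ss_base_model_version": "sd_v1", "ss_network_args": '{"conv_dim": 4}'}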
+
+
+def get_sai_model_spec(
+    state_dict: dict,
+    args: argparse.Namespace,
+    sdxl: bool,
+    lora: bool,
+    textual_inversion: bool,
+    is_stable_diffusion_ckpt: Optional[bool] = None,  # None for TI and LoRA
+):
+    timestamp = time.time()
+
+    v2 = args.v2
+    v_parameterization = args.v_parameterization
+    reso = args.resolution
+
+    title = args.metadata_title if args.metadata_title is not None else args.output_name
+
+    if args.min_timestep is not None or args.max_timestep is not None:
+        min_time_step = args.min_timestep if args.min_timestep is not None else 0
+        max_time_step = args.max_timestep if args.max_timestep is not None else 1000
+        timesteps = (min_time_step, max_time_step)
+    else:
+        timesteps = None
+
+    metadata = sai_model_spec.build_metadata(
+        state_dict,
+        v2,
+        v_parameterization,
+        sdxl,
+        lora,
+        textual_inversion,
+        timestamp,
+        title=title,
+        reso=reso,
+        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
+        author=args.metadata_author,
+        description=args.metadata_description,
+        license=args.metadata_license,
+        tags=args.metadata_tags,
+        timesteps=timesteps,
+        clip_skip=args.clip_skip,  # None or int
+    )
+    return metadata
+
+
+def add_sd_models_arguments(parser: argparse.ArgumentParser):
+    # for pretrained models
+    parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
+    parser.add_argument(
+        "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
+    )
+    parser.add_argument(
+        "--tokenizer_cache_dir",
+        type=str,
+        default=None,
+        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
+    )
+
+
+def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--optimizer_type",
+        type=str,
+        default="",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+    )
+
+    # backward compatibility
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)",
+    )
+    parser.add_argument(
+        "--use_lion_optimizer",
+        action="store_true",
+        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)",
+    )
+
+    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+    parser.add_argument(
+        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない"
+    )
+
+    parser.add_argument(
+        "--optimizer_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
+    )
+
+    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
+    parser.add_argument(
+        "--lr_scheduler_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
+    )
+
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
+    )
+    parser.add_argument(
+        "--lr_warmup_steps",
+        type=int,
+        default=0,
+        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+    )
+    parser.add_argument(
+        "--lr_scheduler_num_cycles",
+        type=int,
+        default=1,
+        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
+    )
+    parser.add_argument(
+        "--lr_scheduler_power",
+        type=float,
+        default=1,
+        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
+    )
+
+
+def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
+    parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
+    parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
+    parser.add_argument(
+        "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
+    )
+    parser.add_argument(
+        "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
+    )
+    parser.add_argument(
+        "--huggingface_path_in_repo",
+        type=str,
+        default=None,
+        help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
+    )
+    parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
+    parser.add_argument(
+        "--huggingface_repo_visibility",
+        type=str,
+        default=None,
+        help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
+    )
+    parser.add_argument(
+        "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
+    )
+    parser.add_argument(
+        "--resume_from_huggingface",
+        action="store_true",
+        help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
+    )
+    parser.add_argument(
+        "--async_upload",
+        action="store_true",
+        help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
+    )
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving / 保存時に精度を変更して保存する",
+    )
+    parser.add_argument(
+        "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する"
+    )
+    parser.add_argument(
+        "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
+    )
+    parser.add_argument(
+        "--save_n_epoch_ratio",
+        type=int,
+        default=None,
+        help="save checkpoints so that at least N files are saved in total (for example, 5 means at least 5 checkpoints) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)",
+    )
+    parser.add_argument(
+        "--save_last_n_epochs",
+        type=int,
+        default=None,
+        help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
+    )
+    parser.add_argument(
+        "--save_last_n_epochs_state",
+        type=int,
+        default=None,
+        help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
+    )
+    parser.add_argument(
+        "--save_last_n_steps",
+        type=int,
+        default=None,
+        help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
+    )
+    parser.add_argument(
+        "--save_last_n_steps_state",
+        type=int,
+        default=None,
+        help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
+    )
+    parser.add_argument(
+        "--save_state",
+        action="store_true",
+        help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
+    )
+    parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
+
+    parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
+    parser.add_argument(
+        "--max_token_length",
+        type=int,
+        default=None,
+        choices=[None, 150, 225],
+        help="max token length of text encoder (default is 75; 150 or 225 can be specified) / text encoderのトークンの最大長(未指定で75、150または225が指定可)",
+    )
+    parser.add_argument(
+        "--mem_eff_attn",
+        action="store_true",
+        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
+    )
+    parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
+    parser.add_argument(
+        "--sdpa",
+        action="store_true",
+        help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
+    )
+    parser.add_argument(
+        "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
+    )
+
+    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+    parser.add_argument(
+        "--max_train_epochs",
+        type=int,
+        default=None,
+        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
+    )
+    parser.add_argument(
+        "--max_data_loader_n_workers",
+        type=int,
+        default=8,
+        help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
+    )
+    parser.add_argument(
+        "--persistent_data_loader_workers",
+        action="store_true",
+        help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
+    parser.add_argument(
+        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
+    )
+    parser.add_argument(
+        "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
+    )
+    parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
+    parser.add_argument(
+        "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
+    )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--ddp_timeout",
+        type=int,
+        default=None,
+        help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
+    )
+    parser.add_argument(
+        "--ddp_gradient_as_bucket_view",
+        action="store_true",
+        help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
+    )
+    parser.add_argument(
+        "--ddp_static_graph",
+        action="store_true",
+        help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
+    )
+    parser.add_argument(
+        "--clip_skip",
+        type=int,
+        default=None,
+        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default=None,
+        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
+    )
+    parser.add_argument(
+        "--log_with",
+        type=str,
+        default=None,
+        choices=["tensorboard", "wandb", "all"],
+        help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
+    )
+    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
+    parser.add_argument(
+        "--log_tracker_name",
+        type=str,
+        default=None,
+        help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
+    )
+    parser.add_argument(
+        "--log_tracker_config",
+        type=str,
+        default=None,
+        help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
+    )
+    parser.add_argument(
+        "--wandb_api_key",
+        type=str,
+        default=None,
+        help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
+    )
+    parser.add_argument(
+        "--noise_offset",
+        type=float,
+        default=None,
+        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
+    )
+    parser.add_argument(
+        "--multires_noise_iterations",
+        type=int,
+        default=None,
+        help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
+    )
+    parser.add_argument(
+        "--ip_noise_gamma",
+        type=float,
+        default=None,
+        help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
+        + "/  input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
+    )
+    # parser.add_argument(
+    #     "--perlin_noise",
+    #     type=int,
+    #     default=None,
+    #     help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
+    # )
+    parser.add_argument(
+        "--multires_noise_discount",
+        type=float,
+        default=0.3,
+        help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)",
+    )
+    parser.add_argument(
+        "--adaptive_noise_scale",
+        type=float,
+        default=None,
+        help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
+    )
+    parser.add_argument(
+        "--zero_terminal_snr",
+        action="store_true",
+        help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
+    )
+    parser.add_argument(
+        "--min_timestep",
+        type=int,
+        default=None,
+        help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
+    )
+    parser.add_argument(
+        "--max_timestep",
+        type=int,
+        default=None,
+        help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
+    )
+
+    parser.add_argument(
+        "--lowram",
+        action="store_true",
+        help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)",
+    )
+
+    parser.add_argument(
+        "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
+    )
+    parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
+    parser.add_argument(
+        "--sample_every_n_epochs",
+        type=int,
+        default=None,
+        help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
+    )
+    parser.add_argument(
+        "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル"
+    )
+    parser.add_argument(
+        "--sample_sampler",
+        type=str,
+        default="ddim",
+        choices=[
+            "ddim",
+            "pndm",
+            "lms",
+            "euler",
+            "euler_a",
+            "heun",
+            "dpm_2",
+            "dpm_2_a",
+            "dpmsolver",
+            "dpmsolver++",
+            "dpmsingle",
+            "k_lms",
+            "k_euler",
+            "k_euler_a",
+            "k_dpm_2",
+            "k_dpm_2_a",
+        ],
+        help="sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
+    )
+
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        default=None,
+        help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
+    )
+    parser.add_argument(
+        "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
+    )
+
+    # SAI Model spec
+    parser.add_argument(
+        "--metadata_title",
+        type=str,
+        default=None,
+        help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
+    )
+    parser.add_argument(
+        "--metadata_author",
+        type=str,
+        default=None,
+        help="author name for model metadata / メタデータに書き込まれるモデル作者名",
+    )
+    parser.add_argument(
+        "--metadata_description",
+        type=str,
+        default=None,
+        help="description for model metadata / メタデータに書き込まれるモデル説明",
+    )
+    parser.add_argument(
+        "--metadata_license",
+        type=str,
+        default=None,
+        help="license for model metadata / メタデータに書き込まれるモデルライセンス",
+    )
+    parser.add_argument(
+        "--metadata_tags",
+        type=str,
+        default=None,
+        help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
+    )
+
+    if support_dreambooth:
+        # DreamBooth training
+        parser.add_argument(
+            "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
+        )
+
+
+def verify_training_args(args: argparse.Namespace):
+    if args.v_parameterization and not args.v2:
+        print("v_parameterization should be used with v2, not v1 or SDXL / v1やsdxlでv_parameterizationを使用することは想定されていません")
+    if args.v2 and args.clip_skip is not None:
+        print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
+
+    if args.cache_latents_to_disk and not args.cache_latents:
+        args.cache_latents = True
+        print(
+            "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
+        )
+
+    # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
+    # # we could count these with a list, but just spell out the combinations
+    # if args.noise_offset is not None and args.multires_noise_iterations is not None:
+    #     raise ValueError(
+    #         "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
+    #     )
+    # if args.noise_offset is not None and args.perlin_noise is not None:
+    #     raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
+    # if args.perlin_noise is not None and args.multires_noise_iterations is not None:
+    #     raise ValueError(
+    #         "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
+    #     )
+
+    if args.adaptive_noise_scale is not None and args.noise_offset is None:
+        raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
+
+    if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
+        raise ValueError(
+            "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
+        )
+
+    if args.v_pred_like_loss and args.v_parameterization:
+        raise ValueError(
+            "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
+        )
+
+    if args.zero_terminal_snr and not args.v_parameterization:
+        print(
+            "zero_terminal_snr is enabled, but v_parameterization is not enabled. training results may be unexpected"
+            + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
+        )
+
+
+def add_dataset_arguments(
+    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
+):
+    # dataset common
+    parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
+    parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
+    parser.add_argument(
+        "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
+    )
+    parser.add_argument(
+        "--caption_extention",
+        type=str,
+        default=None,
+        help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)",
+    )
+    parser.add_argument(
+        "--keep_tokens",
+        type=int,
+        default=0,
+        help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
+    )
+    parser.add_argument(
+        "--keep_tokens_separator",
+        type=str,
+        default="",
+        help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+        + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
+    )
+    parser.add_argument(
+        "--caption_prefix",
+        type=str,
+        default=None,
+        help="prefix for caption text / captionのテキストの先頭に付ける文字列",
+    )
+    parser.add_argument(
+        "--caption_suffix",
+        type=str,
+        default=None,
+        help="suffix for caption text / captionのテキストの末尾に付ける文字列",
+    )
+    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
+    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
+    parser.add_argument(
+        "--face_crop_aug_range",
+        type=str,
+        default=None,
+        help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)",
+    )
+    parser.add_argument(
+        "--random_crop",
+        action="store_true",
+        help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)",
+    )
+    parser.add_argument(
+        "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)"
+    )
+    parser.add_argument(
+        "--resolution",
+        type=str,
+        default=None,
+        help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)",
+    )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
+    )
+    parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
+    parser.add_argument(
+        "--cache_latents_to_disk",
+        action="store_true",
+        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
+    )
+    parser.add_argument(
+        "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
+    )
+    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
+    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
+    parser.add_argument(
+        "--bucket_reso_steps",
+        type=int,
+        default=64,
+        help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
+    )
+    parser.add_argument(
+        "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
+    )
+
+    parser.add_argument(
+        "--token_warmup_min",
+        type=int,
+        default=1,
+        help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
+    )
+    parser.add_argument(
+        "--token_warmup_step",
+        type=float,
+        default=0,
+        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
+    )
+
+    parser.add_argument(
+        "--dataset_class",
+        type=str,
+        default=None,
+        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
+    )
+
+    if support_caption_dropout:
+        # Textual Inversion does not support caption dropout
+        # prefix with "caption" to avoid confusion with tensor Dropout; the default for every_n_epochs follows the other options
+        parser.add_argument(
+            "--caption_dropout_rate", type=float, default=0.0, help="Rate of caption dropout (0.0~1.0) / captionをdropoutする割合"
+        )
+        parser.add_argument(
+            "--caption_dropout_every_n_epochs",
+            type=int,
+            default=0,
+            help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする",
+        )
+        parser.add_argument(
+            "--caption_tag_dropout_rate",
+            type=float,
+            default=0.0,
+            help="Rate of dropout for comma separated tokens (0.0~1.0) / カンマ区切りのタグをdropoutする割合",
+        )
+
+    if support_dreambooth:
+        # DreamBooth dataset
+        parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
+
+    if support_caption:
+        # caption dataset
+        parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
+        parser.add_argument(
+            "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数"
+        )
+
+
+def add_sd_saving_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--save_model_as",
+        type=str,
+        default=None,
+        choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
+        help="format to save the model (default is same as original) / モデル保存時の形式(未指定時は元モデルと同じ)",
+    )
+    parser.add_argument(
+        "--use_safetensors",
+        action="store_true",
+        help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
+    )
+
+
+def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    if not args.config_file:
+        return args
+
+    config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
+
+    if args.output_config:
+        # check if config file exists
+        if os.path.exists(config_path):
+            print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
+            exit(1)
+
+        # convert args to dictionary
+        args_dict = vars(args)
+
+        # remove unnecessary keys
+        for key in ["config_file", "output_config", "wandb_api_key"]:
+            if key in args_dict:
+                del args_dict[key]
+
+        # get default args from parser
+        default_args = vars(parser.parse_args([]))
+
+        # remove default values: cannot use args_dict.items directly because it will be changed during iteration
+        for key, value in list(args_dict.items()):
+            if key in default_args and value == default_args[key]:
+                del args_dict[key]
+
+        # convert Path to str in dictionary
+        for key, value in args_dict.items():
+            if isinstance(value, pathlib.Path):
+                args_dict[key] = str(value)
+
+        # convert to toml and output to file
+        with open(config_path, "w") as f:
+            toml.dump(args_dict, f)
+
+        print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
+        exit(0)
+
+    if not os.path.exists(config_path):
+        print(f"{config_path} not found.")
+        exit(1)
+
+    print(f"Loading settings from {config_path}...")
+    with open(config_path, "r") as f:
+        config_dict = toml.load(f)
+
+    # combine all sections into one
+    ignore_nesting_dict = {}
+    for section_name, section_dict in config_dict.items():
+        # if value is not dict, save key and value as is
+        if not isinstance(section_dict, dict):
+            ignore_nesting_dict[section_name] = section_dict
+            continue
+
+        # if value is dict, save all key and value into one dict
+        for key, value in section_dict.items():
+            ignore_nesting_dict[key] = value
+
+    config_args = argparse.Namespace(**ignore_nesting_dict)
+    args = parser.parse_args(namespace=config_args)
+    args.config_file = os.path.splitext(args.config_file)[0]
+    print(args.config_file)
+
+    return args
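+
+
+# Illustrative .toml for --config_file (hypothetical values): section headers are only organizational;
+# all keys from every section are flattened into a single namespace before parse_args runs, so the
+# keys must match the command-line option names.
+#
+#   [model]
+#   pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
+#
+#   [training]
+#   max_train_steps = 2000
+#   mixed_precision = "fp16"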
+
+
+# endregion
+
+# region utils
+
+
+def resume_from_local_or_hf_if_specified(accelerator, args):
+    if not args.resume:
+        return
+
+    if not args.resume_from_huggingface:
+        print(f"resume training from local state: {args.resume}")
+        accelerator.load_state(args.resume)
+        return
+
+    print(f"resume training from huggingface state: {args.resume}")
+    repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
+    path_in_repo = "/".join(args.resume.split("/")[2:])
+    revision = None
+    repo_type = None
+    if ":" in path_in_repo:
+        divided = path_in_repo.split(":")
+        if len(divided) == 2:
+            path_in_repo, revision = divided
+            repo_type = "model"
+        else:
+            path_in_repo, revision, repo_type = divided
+    print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
+
+    list_files = huggingface_util.list_dir(
+        repo_id=repo_id,
+        subfolder=path_in_repo,
+        revision=revision,
+        token=args.huggingface_token,
+        repo_type=repo_type,
+    )
+
+    async def download(filename) -> str:
+        def task():
+            return hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                revision=revision,
+                repo_type=repo_type,
+                token=args.huggingface_token,
+            )
+
+        return await asyncio.get_event_loop().run_in_executor(None, task)
+
+    loop = asyncio.get_event_loop()
+    results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
+    if len(results) == 0:
+        raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
+    dirname = os.path.dirname(results[0])
+    accelerator.load_state(dirname)
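+
+
+# Parsing sketch for --resume with --resume_from_huggingface (hypothetical repo):
+#   --resume "user/my-repo/states/epoch-000010:main:model"
+#   -> repo_id="user/my-repo", path_in_repo="states/epoch-000010", revision="main", repo_type="model"
+# With a single ":" only the revision is given and repo_type defaults to "model"; with no ":" both
+# revision and repo_type remain None.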
+
+
+def get_optimizer(args, trainable_params):
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+
+    optimizer_type = args.optimizer_type
+    if args.use_8bit_adam:
+        assert (
+            not args.use_lion_optimizer
+        ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
+        assert (
+            optimizer_type is None or optimizer_type == ""
+        ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
+        optimizer_type = "AdamW8bit"
+
+    elif args.use_lion_optimizer:
+        assert (
+            optimizer_type is None or optimizer_type == ""
+        ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
+        optimizer_type = "Lion"
+
+    if optimizer_type is None or optimizer_type == "":
+        optimizer_type = "AdamW"
+    optimizer_type = optimizer_type.lower()
+
+    # parse the optimizer arguments
+    optimizer_kwargs = {}
+    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+        for arg in args.optimizer_args:
+            key, value = arg.split("=")
+            value = ast.literal_eval(value)
+
+            # value = value.split(",")
+            # for i in range(len(value)):
+            #     if value[i].lower() == "true" or value[i].lower() == "false":
+            #         value[i] = value[i].lower() == "true"
+            #     else:
+            #         value[i] = ast.float(value[i])
+            # if len(value) == 1:
+            #     value = value[0]
+            # else:
+            #     value = tuple(value)
+
+            optimizer_kwargs[key] = value
+    # print("optkwargs:", optimizer_kwargs)
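+    # Example (illustrative): --optimizer_args "weight_decay=0.01" "betas=0.9,0.999" yields
+    # optimizer_kwargs == {"weight_decay": 0.01, "betas": (0.9, 0.999)}, since each value is parsed
+    # with ast.literal_eval (a bare comma-separated list becomes a tuple).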
+
+    lr = args.learning_rate
+    optimizer = None
+
+    if optimizer_type == "Lion".lower():
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+        print(f"use Lion optimizer | {optimizer_kwargs}")
+        optimizer_class = lion_pytorch.Lion
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type.endswith("8bit".lower()):
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+
+        if optimizer_type == "AdamW8bit".lower():
+            print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+            optimizer_class = bnb.optim.AdamW8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+        elif optimizer_type == "SGDNesterov8bit".lower():
+            print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+            if "momentum" not in optimizer_kwargs:
+                print(
+                    f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
+                )
+                optimizer_kwargs["momentum"] = 0.9
+
+            optimizer_class = bnb.optim.SGD8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+        elif optimizer_type == "Lion8bit".lower():
+            print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
+            try:
+                optimizer_class = bnb.optim.Lion8bit
+            except AttributeError:
+                raise AttributeError(
+                    "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
+                )
+        elif optimizer_type == "PagedAdamW8bit".lower():
+            print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
+            try:
+                optimizer_class = bnb.optim.PagedAdamW8bit
+            except AttributeError:
+                raise AttributeError(
+                    "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+                )
+        elif optimizer_type == "PagedLion8bit".lower():
+            print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
+            try:
+                optimizer_class = bnb.optim.PagedLion8bit
+            except AttributeError:
+                raise AttributeError(
+                    "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+                )
+
+        if optimizer is None:
+            # AdamW8bit and SGDNesterov8bit were already instantiated above (SGDNesterov8bit needs
+            # nesterov=True), so only instantiate here for the branches that only set optimizer_class.
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "PagedAdamW".lower():
+        print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+        try:
+            optimizer_class = bnb.optim.PagedAdamW
+        except AttributeError:
+            raise AttributeError(
+                "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "PagedAdamW32bit".lower():
+        print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+        try:
+            optimizer_class = bnb.optim.PagedAdamW32bit
+        except AttributeError:
+            raise AttributeError(
+                "No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "SGDNesterov".lower():
+        print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
+        if "momentum" not in optimizer_kwargs:
+            print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+            optimizer_kwargs["momentum"] = 0.9
+
+        optimizer_class = torch.optim.SGD
+        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+    elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
+        # check lr and lr_count, and print warning
+        actual_lr = lr
+        lr_count = 1
+        if type(trainable_params) == list and type(trainable_params[0]) == dict:
+            lrs = set()
+            actual_lr = trainable_params[0].get("lr", actual_lr)
+            for group in trainable_params:
+                lrs.add(group.get("lr", actual_lr))
+            lr_count = len(lrs)
+
+        if actual_lr <= 0.1:
+            print(
+                f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
+            )
+            print("recommend option: lr=1.0 / 推奨は1.0です")
+        if lr_count > 1:
+            print(
+                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+            )
+
+        if optimizer_type.startswith("DAdapt".lower()):
+            # DAdaptation family
+            # check dadaptation is installed
+            try:
+                import dadaptation
+                import dadaptation.experimental as experimental
+            except ImportError:
+                raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+
+            # set optimizer
+            if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
+                optimizer_class = experimental.DAdaptAdamPreprint
+                print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdaGrad".lower():
+                optimizer_class = dadaptation.DAdaptAdaGrad
+                print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdam".lower():
+                optimizer_class = dadaptation.DAdaptAdam
+                print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdan".lower():
+                optimizer_class = dadaptation.DAdaptAdan
+                print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdanIP".lower():
+                optimizer_class = experimental.DAdaptAdanIP
+                print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptLion".lower():
+                optimizer_class = dadaptation.DAdaptLion
+                print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptSGD".lower():
+                optimizer_class = dadaptation.DAdaptSGD
+                print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
+            else:
+                raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+        else:
+            # Prodigy
+            # check Prodigy is installed
+            try:
+                import prodigyopt
+            except ImportError:
+                raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
+
+            print(f"use Prodigy optimizer | {optimizer_kwargs}")
+            optimizer_class = prodigyopt.Prodigy
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "Adafactor".lower():
+        # check the arguments and adjust them as needed
+        if "relative_step" not in optimizer_kwargs:
+            optimizer_kwargs["relative_step"] = True  # default
+        if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+            print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
+            optimizer_kwargs["relative_step"] = True
+        print(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+        if optimizer_kwargs["relative_step"]:
+            print(f"relative_step is true / relative_stepがtrueです")
+            if lr != 0.0:
+                print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+            args.learning_rate = None
+
+            # when trainable_params is a list of param groups: drop the per-group lr
+            if type(trainable_params) == list and type(trainable_params[0]) == dict:
+                has_group_lr = False
+                for group in trainable_params:
+                    p = group.pop("lr", None)
+                    has_group_lr = has_group_lr or (p is not None)
+
+                if has_group_lr:
+                    # also invalidate the args. TODO: not ideal, since this reverses the dependency direction
+                    print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
+                    args.unet_lr = None
+                    args.text_encoder_lr = None
+
+            if args.lr_scheduler != "adafactor":
+                print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+            args.lr_scheduler = f"adafactor:{lr}"  # ちょっと微妙だけど
+
+            lr = None
+        else:
+            if args.max_grad_norm != 0.0:
+                print(
+                    f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
+                )
+            if args.lr_scheduler != "constant_with_warmup":
+                print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+            if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+                print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+        optimizer_class = transformers.optimization.Adafactor
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "AdamW".lower():
+        print(f"use AdamW optimizer | {optimizer_kwargs}")
+        optimizer_class = torch.optim.AdamW
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    if optimizer is None:
+        # use an arbitrary optimizer specified by name
+        optimizer_type = args.optimizer_type  # the non-lowercased name (a bit awkward)
+        print(f"use {optimizer_type} | {optimizer_kwargs}")
+        if "." not in optimizer_type:
+            optimizer_module = torch.optim
+        else:
+            values = optimizer_type.split(".")
+            optimizer_module = importlib.import_module(".".join(values[:-1]))
+            optimizer_type = values[-1]
+
+        optimizer_class = getattr(optimizer_module, optimizer_type)
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+    optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+    return optimizer_name, optimizer_args, optimizer
+
+
+# Modified version of get_scheduler() from diffusers.optimization.
+# Adds some checks and features on top of the original function.
+
+
+def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
+    """
+    Unified API to get any scheduler from its name.
+    """
+    name = args.lr_scheduler
+    num_warmup_steps: Optional[int] = args.lr_warmup_steps
+    num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_cycles = args.lr_scheduler_num_cycles
+    power = args.lr_scheduler_power
+
+    lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
+    if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+        for arg in args.lr_scheduler_args:
+            key, value = arg.split("=")
+            value = ast.literal_eval(value)
+            lr_scheduler_kwargs[key] = value
+
+    def wrap_check_needless_num_warmup_steps(return_vals):
+        if num_warmup_steps is not None and num_warmup_steps != 0:
+            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+        return return_vals
+
+    # using any lr_scheduler from other library
+    if args.lr_scheduler_type:
+        lr_scheduler_type = args.lr_scheduler_type
+        print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+        if "." not in lr_scheduler_type:  # default to use torch.optim
+            lr_scheduler_module = torch.optim.lr_scheduler
+        else:
+            values = lr_scheduler_type.split(".")
+            lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+            lr_scheduler_type = values[-1]
+        lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+        return wrap_check_needless_num_warmup_steps(lr_scheduler)
+
+    if name.startswith("adafactor"):
+        assert (
+            type(optimizer) == transformers.optimization.Adafactor
+        ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+        initial_lr = float(name.split(":")[1])
+        # print("adafactor scheduler init lr", initial_lr)
+        return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
+
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+    if name == SchedulerType.CONSTANT:
+        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles,
+            **lr_scheduler_kwargs,
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+
+
+def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
+    # backward compatibility
+    if args.caption_extention is not None:
+        args.caption_extension = args.caption_extention
+        args.caption_extention = None
+
+    # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
+    if args.resolution is not None:
+        args.resolution = tuple([int(r) for r in args.resolution.split(",")])
+        if len(args.resolution) == 1:
+            args.resolution = (args.resolution[0], args.resolution[0])
+        assert (
+            len(args.resolution) == 2
+        ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
+
+    if args.face_crop_aug_range is not None:
+        args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
+        assert (
+            len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1]
+        ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
+    else:
+        args.face_crop_aug_range = None
+
+    if support_metadata:
+        if args.in_json is not None and (args.color_aug or args.random_crop):
+            print(
+                f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます"
+            )
+
+
+def load_tokenizer(args: argparse.Namespace):
+    print("prepare tokenizer")
+    original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
+
+    tokenizer: CLIPTokenizer = None
+    if args.tokenizer_cache_dir:
+        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
+        if os.path.exists(local_tokenizer_path):
+            print(f"load tokenizer from cache: {local_tokenizer_path}")
+            tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)  # same for v1 and v2
+
+    if tokenizer is None:
+        if args.v2:
+            tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
+        else:
+            tokenizer = CLIPTokenizer.from_pretrained(original_path)
+
+    if hasattr(args, "max_token_length") and args.max_token_length is not None:
+        print(f"update token length: {args.max_token_length}")
+
+    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+        print(f"save Tokenizer to cache: {local_tokenizer_path}")
+        tokenizer.save_pretrained(local_tokenizer_path)
+
+    return tokenizer
+
+
+def prepare_accelerator(args: argparse.Namespace):
+    if args.logging_dir is None:
+        logging_dir = None
+    else:
+        log_prefix = "" if args.log_prefix is None else args.log_prefix
+        logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
+
+    if args.log_with is None:
+        if logging_dir is not None:
+            log_with = "tensorboard"
+        else:
+            log_with = None
+    else:
+        log_with = args.log_with
+        if log_with in ["tensorboard", "all"]:
+            if logging_dir is None:
+                raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
+        if log_with in ["wandb", "all"]:
+            try:
+                import wandb
+            except ImportError:
+                raise ImportError("No wandb / wandb がインストールされていないようです")
+            if logging_dir is not None:
+                os.makedirs(logging_dir, exist_ok=True)
+                os.environ["WANDB_DIR"] = logging_dir
+            if args.wandb_api_key is not None:
+                wandb.login(key=args.wandb_api_key)
+
+    kwargs_handlers = (
+        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
+        DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
+        if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+        else None,
+    )
+    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=log_with,
+        project_dir=logging_dir,
+        kwargs_handlers=kwargs_handlers,
+    )
+    return accelerator
+
+
+def prepare_dtype(args: argparse.Namespace):
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    save_dtype = None
+    if args.save_precision == "fp16":
+        save_dtype = torch.float16
+    elif args.save_precision == "bf16":
+        save_dtype = torch.bfloat16
+    elif args.save_precision == "float":
+        save_dtype = torch.float32
+
+    return weight_dtype, save_dtype
+
+
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
+    name_or_path = args.pretrained_model_name_or_path
+    name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path
+    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
+    if load_stable_diffusion_format:
+        print(f"load StableDiffusion checkpoint: {name_or_path}")
+        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
+            args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
+        )
+    else:
+        # Diffusers model is loaded to CPU
+        print(f"load Diffusers pretrained models: {name_or_path}")
+        try:
+            pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
+        except EnvironmentError as ex:
+            print(
+                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
+            )
+            raise ex
+        text_encoder = pipe.text_encoder
+        vae = pipe.vae
+        unet = pipe.unet
+        del pipe
+
+        # Diffusers U-Net to original U-Net
+        # TODO: it would be better to convert here to the same format as the v2 *.ckpt/*.safetensors
+        # print(f"unet config: {unet.config}")
+        original_unet = UNet2DConditionModel(
+            unet.config.sample_size,
+            unet.config.attention_head_dim,
+            unet.config.cross_attention_dim,
+            unet.config.use_linear_projection,
+            unet.config.upcast_attention,
+        )
+        original_unet.load_state_dict(unet.state_dict())
+        unet = original_unet
+        print("U-Net converted to original U-Net")
+
+    # load the VAE
+    if args.vae is not None:
+        vae = model_util.load_vae(args.vae, weight_dtype)
+        print("additional VAE loaded")
+
+    return text_encoder, vae, unet, load_stable_diffusion_format
+
+
+def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
+    # load models for each process
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
+                args,
+                weight_dtype,
+                accelerator.device if args.lowram else "cpu",
+                unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    return text_encoder, vae, unet, load_stable_diffusion_format
+
+
+def patch_accelerator_for_fp16_training(accelerator):
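+    # Replace GradScaler._unscale_grads_ so that allow_fp16 is always True, which lets
+    # gradients of fp16 parameters be unscaled and thus enables full fp16 training.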
+    org_unscale_grads = accelerator.scaler._unscale_grads_
+
+    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+        return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+    accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
+
+
+def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
+    # with no_token_padding, the length is not max length, return result immediately
+    if input_ids.size()[-1] != tokenizer.model_max_length:
+        return text_encoder(input_ids)[0]
+
+    # input_ids: b,n,77
+    b_size = input_ids.size()[0]
+    input_ids = input_ids.reshape((-1, tokenizer.model_max_length))  # batch_size*3, 77
+
+    if args.clip_skip is None:
+        encoder_hidden_states = text_encoder(input_ids)[0]
+    else:
+        enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
+        encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
+        encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
+
+    # bs*3, 77, 768 or 1024
+    encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
+
+    if args.max_token_length is not None:
+        if args.v2:
+            # v2: stitch the <BOS>...<EOS> <PAD> ... triplets back into <BOS>...<EOS> <PAD> ...
+            # (honestly not sure this implementation is right)
+            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
+            for i in range(1, args.max_token_length, tokenizer.model_max_length):
+                chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]  # from after <BOS> to before the last token
+                if i > 0:
+                    for j in range(len(chunk)):
+                        if input_ids[j, 1] == tokenizer.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
+                            chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
+                states_list.append(chunk)  # from after <BOS> to before <EOS>
+            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+            encoder_hidden_states = torch.cat(states_list, dim=1)
+        else:
+            # v1: stitch the <BOS>...<EOS> triplets back into <BOS>...<EOS>
+            states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
+            for i in range(1, args.max_token_length, tokenizer.model_max_length):
+                states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2])  # from after <BOS> to before <EOS>
+            encoder_hidden_states = torch.cat(states_list, dim=1)
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
+
+    return encoder_hidden_states
+
+
+def pool_workaround(
+    text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
+):
+    r"""
+    Workaround for CLIP's pooling behavior: the original implementation takes the hidden state at the
+    position of the highest token id as the pooled output, which is normally the EOS token. With
+    Textual Inversion, added tokens can have ids higher than EOS, so we explicitly take the hidden
+    state at the EOS position instead.
+
+    Original code from CLIP's pooling function:
+
+    \# text_embeds.shape = [batch_size, sequence_length, transformer.width]
+    \# take features from the eot embedding (eot_token is the highest number in each sequence)
+    \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+    pooled_output = last_hidden_state[
+        torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+        input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+    ]
+    """
+
+    # input_ids: b*n,77
+    # find index for EOS token
+
+    # The following (commented-out) code does not work if one of the input_ids contains multiple EOS tokens (a very odd case)
+    # eos_token_index = torch.where(input_ids == eos_token_id)[1]
+    # eos_token_index = eos_token_index.to(device=last_hidden_state.device)
+
+    # Create a mask where the EOS tokens are
+    eos_token_mask = (input_ids == eos_token_id).int()
+
+    # Use argmax to find the first index of the EOS token for each element in the batch (argmax returns the first maximum)
+    eos_token_index = torch.argmax(eos_token_mask, dim=1)  # this will be 0 if there is no EOS token, it's fine
+    eos_token_index = eos_token_index.to(device=last_hidden_state.device)
+
+    # get hidden states for EOS token
+    pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
+
+    # apply projection: projection may be of different dtype than last_hidden_state
+    pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
+    pooled_output = pooled_output.to(last_hidden_state.dtype)
+
+    return pooled_output
+
+
+def get_hidden_states_sdxl(
+    max_token_length: int,
+    input_ids1: torch.Tensor,
+    input_ids2: torch.Tensor,
+    tokenizer1: CLIPTokenizer,
+    tokenizer2: CLIPTokenizer,
+    text_encoder1: CLIPTextModel,
+    text_encoder2: CLIPTextModelWithProjection,
+    weight_dtype: Optional[str] = None,
+    accelerator: Optional[Accelerator] = None,
+):
+    # input_ids: b,n,77 -> b*n, 77
+    b_size = input_ids1.size()[0]
+    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
+    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+
+    # text_encoder1
+    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
+    hidden_states1 = enc_out["hidden_states"][11]
+
+    # text_encoder2
+    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+    hidden_states2 = enc_out["hidden_states"][-2]  # penuultimate layer
+
+    # pool2 = enc_out["text_embeds"]
+    unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
+    pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
+
+    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+    n_size = 1 if max_token_length is None else max_token_length // 75
+    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
+    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
+    if max_token_length is not None:
+        # bs*n, 77, 768 or 1280
+        # encoder1: stitch the <BOS>...<EOS> triplets back into <BOS>...<EOS>
+        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer1.model_max_length):
+            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
+        hidden_states1 = torch.cat(states_list, dim=1)
+
+        # encoder2: stitch the <BOS>...<EOS> <PAD> ... triplets back into <BOS>...<EOS> <PAD> ...
+        # (honestly not sure this implementation is right)
+        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer2.model_max_length):
+            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
+            # this causes an error:
+            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+            # if i > 1:
+            #     for j in range(len(chunk)):  # batch_size
+            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
+            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
+            states_list.append(chunk)  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+        hidden_states2 = torch.cat(states_list, dim=1)
+
+        # pool2: use the first of each group of n
+        pool2 = pool2[::n_size]
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        hidden_states1 = hidden_states1.to(weight_dtype)
+        hidden_states2 = hidden_states2.to(weight_dtype)
+
+    return hidden_states1, hidden_states2, pool2
+
+
+def default_if_none(value, default):
+    return default if value is None else value
+
+
+def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
+    model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
+    return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext
+
+
+def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
+    model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
+    return STEP_FILE_NAME.format(model_name, step_no) + ext
+
+
+def get_last_ckpt_name(args: argparse.Namespace, ext: str):
+    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
+    return model_name + ext
+
+
+def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
+    if args.save_last_n_epochs is None:
+        return None
+
+    remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
+    if remove_epoch_no < 0:
+        return None
+    return remove_epoch_no
+
+
+def get_remove_step_no(args: argparse.Namespace, step_no: int):
+    if args.save_last_n_steps is None:
+        return None
+
+    # go back save_last_n_steps from step_no and round down to a multiple of save_every_n_steps
+    # to get the step whose checkpoint should be removed
+    # e.g. with save_every_n_steps=10 and save_last_n_steps=30, at step 50 the last 30 steps are kept
+    # and the checkpoint from step 10 is removed
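+    # worked example (same numbers): remove_step_no = 50 - 30 - 1 = 19, then 19 - (19 % 10) = 10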
+    remove_step_no = step_no - args.save_last_n_steps - 1
+    remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+    if remove_step_no < 0:
+        return None
+    return remove_step_no
+
+
+# Epoch-end and stepwise saving are unified here because the metadata contains epoch/step and the arguments are the same.
+# on_epoch_end: True when called at the end of an epoch, False when called every N steps
+def save_sd_model_on_epoch_end_or_stepwise(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    src_path: str,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    text_encoder,
+    unet,
+    vae,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
+        model_util.save_stable_diffusion_checkpoint(
+            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
+        )
+
+    def diffusers_saver(out_dir):
+        model_util.save_diffusers_checkpoint(
+            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
+        )
+
+    save_sd_model_on_epoch_end_or_stepwise_common(
+        args,
+        on_epoch_end,
+        accelerator,
+        save_stable_diffusion_format,
+        use_safetensors,
+        epoch,
+        num_train_epochs,
+        global_step,
+        sd_saver,
+        diffusers_saver,
+    )
+
+
+def save_sd_model_on_epoch_end_or_stepwise_common(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    sd_saver,
+    diffusers_saver,
+):
+    if on_epoch_end:
+        epoch_no = epoch + 1
+        saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
+        if not saving:
+            return
+
+        model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
+        remove_no = get_remove_epoch_no(args, epoch_no)
+    else:
+        # whether to save at this step has already been decided by the caller
+
+        model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
+        epoch_no = epoch  # e.g. this is 0 when saving in the middle of the first epoch; it is stored in the SD model
+        remove_no = get_remove_step_no(args, global_step)
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    if save_stable_diffusion_format:
+        ext = ".safetensors" if use_safetensors else ".ckpt"
+
+        if on_epoch_end:
+            ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
+        else:
+            ckpt_name = get_step_ckpt_name(args, ext, global_step)
+
+        ckpt_file = os.path.join(args.output_dir, ckpt_name)
+        print(f"\nsaving checkpoint: {ckpt_file}")
+        sd_saver(ckpt_file, epoch_no, global_step)
+
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
+
+        # remove older checkpoints
+        if remove_no is not None:
+            if on_epoch_end:
+                remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no)
+            else:
+                remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no)
+
+            remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name)
+            if os.path.exists(remove_ckpt_file):
+                print(f"removing old checkpoint: {remove_ckpt_file}")
+                os.remove(remove_ckpt_file)
+
+    else:
+        if on_epoch_end:
+            out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
+        else:
+            out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
+
+        print(f"\nsaving model: {out_dir}")
+        diffusers_saver(out_dir)
+
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, out_dir, "/" + model_name)
+
+        # remove older checkpoints
+        if remove_no is not None:
+            if on_epoch_end:
+                remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
+            else:
+                remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
+
+            if os.path.exists(remove_out_dir):
+                print(f"removing old model: {remove_out_dir}")
+                shutil.rmtree(remove_out_dir)
+
+    if args.save_state:
+        if on_epoch_end:
+            save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
+        else:
+            save_and_remove_state_stepwise(args, accelerator, global_step)
+
+
+def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
+    model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
+
+    print(f"\nsaving state at epoch {epoch_no}")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        print("uploading state to huggingface.")
+        huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
+
+    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+    if last_n_epochs is not None:
+        remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
+        state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
+        if os.path.exists(state_dir_old):
+            print(f"removing old state: {state_dir_old}")
+            shutil.rmtree(state_dir_old)
+
+
+def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
+    model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
+
+    print(f"\nsaving state at step {step_no}")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        print("uploading state to huggingface.")
+        huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
+
+    last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
+    if last_n_steps is not None:
+        # go back last_n_steps from step_no and round down to a multiple of save_every_n_steps to get the state to remove
+        remove_step_no = step_no - last_n_steps - 1
+        remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+
+        if remove_step_no > 0:
+            state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
+            if os.path.exists(state_dir_old):
+                print(f"removing old state: {state_dir_old}")
+                shutil.rmtree(state_dir_old)
+
+
+def save_state_on_train_end(args: argparse.Namespace, accelerator):
+    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
+
+    print("\nsaving last state.")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
+    accelerator.save_state(state_dir)
+
+    if args.save_state_to_huggingface:
+        print("uploading last state to huggingface.")
+        huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
+
+
+def save_sd_model_on_train_end(
+    args: argparse.Namespace,
+    src_path: str,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    text_encoder,
+    unet,
+    vae,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
+        model_util.save_stable_diffusion_checkpoint(
+            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae
+        )
+
+    def diffusers_saver(out_dir):
+        model_util.save_diffusers_checkpoint(
+            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
+        )
+
+    save_sd_model_on_train_end_common(
+        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
+    )
+
+
+def save_sd_model_on_train_end_common(
+    args: argparse.Namespace,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    epoch: int,
+    global_step: int,
+    sd_saver,
+    diffusers_saver,
+):
+    model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
+
+    if save_stable_diffusion_format:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+        ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
+        ckpt_file = os.path.join(args.output_dir, ckpt_name)
+
+        print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
+        sd_saver(ckpt_file, epoch, global_step)
+
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
+    else:
+        out_dir = os.path.join(args.output_dir, model_name)
+        os.makedirs(out_dir, exist_ok=True)
+
+        print(f"save trained model as Diffusers to {out_dir}")
+        diffusers_saver(out_dir)
+
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
+
+
+def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
+    # Sample noise that we'll add to the latents
+    noise = torch.randn_like(latents, device=latents.device)
+    if args.noise_offset:
+        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
+    if args.multires_noise_iterations:
+        noise = custom_train_functions.pyramid_noise_like(
+            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
+        )
+
+    # Sample a random timestep for each image
+    b_size = latents.shape[0]
+    min_timestep = 0 if args.min_timestep is None else args.min_timestep
+    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
+
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
+    timesteps = timesteps.long()
+
+    # Add noise to the latents according to the noise magnitude at each timestep
+    # (this is the forward diffusion process)
+    if args.ip_noise_gamma:
+        noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
+    else:
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+    return noise, noisy_latents, timesteps
+
+
+def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
+    names = []
+    if including_unet:
+        names.append("unet")
+    names.append("text_encoder1")
+    names.append("text_encoder2")
+
+    append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
+
+
+def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
+    lrs = lr_scheduler.get_last_lr()
+
+    for lr_index in range(len(lrs)):
+        name = names[lr_index]
+        logs["lr/" + name] = float(lrs[lr_index])
+
+        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
+            logs["lr/d*lr/" + name] = (
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
+            )
+
+
+# scheduler:
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDLER_SCHEDULE = "scaled_linear"
+
+
+def get_my_scheduler(
+    *,
+    sample_sampler: str,
+    v_parameterization: bool,
+):
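+    # Illustrative usage: get_my_scheduler(sample_sampler="euler_a", v_parameterization=False)
+    # returns an EulerAncestralDiscreteScheduler configured with the SD beta schedule below.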
+    sched_init_args = {}
+    if sample_sampler == "ddim":
+        scheduler_cls = DDIMScheduler
+    elif sample_sampler == "ddpm":  # ddpmはおかしくなるのでoptionから外してある
+        scheduler_cls = DDPMScheduler
+    elif sample_sampler == "pndm":
+        scheduler_cls = PNDMScheduler
+    elif sample_sampler == "lms" or sample_sampler == "k_lms":
+        scheduler_cls = LMSDiscreteScheduler
+    elif sample_sampler == "euler" or sample_sampler == "k_euler":
+        scheduler_cls = EulerDiscreteScheduler
+    elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
+        scheduler_cls = EulerAncestralDiscreteScheduler
+    elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
+        scheduler_cls = DPMSolverMultistepScheduler
+        sched_init_args["algorithm_type"] = sample_sampler
+    elif sample_sampler == "dpmsingle":
+        scheduler_cls = DPMSolverSinglestepScheduler
+    elif sample_sampler == "heun":
+        scheduler_cls = HeunDiscreteScheduler
+    elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
+        scheduler_cls = KDPM2DiscreteScheduler
+    elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
+        scheduler_cls = KDPM2AncestralDiscreteScheduler
+    else:
+        scheduler_cls = DDIMScheduler
+
+    if v_parameterization:
+        sched_init_args["prediction_type"] = "v_prediction"
+
+    scheduler = scheduler_cls(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDLER_SCHEDULE,
+        **sched_init_args,
+    )
+
+    # force clip_sample to True
+    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
+        # print("set clip_sample to True")
+        scheduler.config.clip_sample = True
+
+    return scheduler
+
+
+def sample_images(*args, **kwargs):
+    return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
+
+
+def line_to_prompt_dict(line: str) -> dict:
+    # subset of gen_img_diffusers
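+    # Illustrative example: "a cat --w 640 --h 480 --s 28 --l 7.5 --d 42 --n lowres" parses to
+    # {"prompt": "a cat", "width": 640, "height": 480, "sample_steps": 28, "scale": 7.5,
+    #  "seed": 42, "negative_prompt": "lowres"}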
+    prompt_args = line.split(" --")
+    prompt_dict = {}
+    prompt_dict["prompt"] = prompt_args[0]
+
+    for parg in prompt_args:
+        try:
+            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict["width"] = int(m.group(1))
+                continue
+
+            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict["height"] = int(m.group(1))
+                continue
+
+            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict["seed"] = int(m.group(1))
+                continue
+
+            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+            if m:  # steps
+                prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
+                continue
+
+            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+            if m:  # scale
+                prompt_dict["scale"] = float(m.group(1))
+                continue
+
+            m = re.match(r"n (.+)", parg, re.IGNORECASE)
+            if m:  # negative prompt
+                prompt_dict["negative_prompt"] = m.group(1)
+                continue
+
+            m = re.match(r"ss (.+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict["sample_sampler"] = m.group(1)
+                continue
+
+            m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict["controlnet_image"] = m.group(1)
+                continue
+
+        except ValueError as ex:
+            print(f"Exception in parsing / 解析エラー: {parg}")
+            print(ex)
+
+    return prompt_dict
+
+
+def sample_images_common(
+    pipe_class,
+    accelerator: Accelerator,
+    args: argparse.Namespace,
+    epoch,
+    steps,
+    device,
+    vae,
+    tokenizer,
+    text_encoder,
+    unet,
+    prompt_replacement=None,
+    controlnet=None,
+):
+    """
+    A modified StableDiffusionLongPromptWeightingPipeline is used, so clip skip and prompt weighting are supported.
+    """
+    if steps == 0:
+        if not args.sample_at_first:
+            return
+    else:
+        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+            return
+        if args.sample_every_n_epochs is not None:
+            # sample_every_n_steps is ignored in this case
+            if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                return
+        else:
+            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+                return
+
+    print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
+    if not os.path.isfile(args.sample_prompts):
+        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+        return
+
+    org_vae_device = vae.device  # should be on the CPU
+    vae.to(device)
+
+    # unwrap unet and text_encoder(s)
+    unet = accelerator.unwrap_model(unet)
+    if isinstance(text_encoder, (list, tuple)):
+        text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
+    else:
+        text_encoder = accelerator.unwrap_model(text_encoder)
+
+    # read prompts
+
+    # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
+    #     prompts = f.readlines()
+
+    if args.sample_prompts.endswith(".txt"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
+    elif args.sample_prompts.endswith(".toml"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            data = toml.load(f)
+        prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
+    elif args.sample_prompts.endswith(".json"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            prompts = json.load(f)
+
+    schedulers: dict = {}
+    default_scheduler = get_my_scheduler(
+        sample_sampler=args.sample_sampler,
+        v_parameterization=args.v_parameterization,
+    )
+    schedulers[args.sample_sampler] = default_scheduler
+
+    pipeline = pipe_class(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=default_scheduler,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=False,
+        clip_skip=args.clip_skip,
+    )
+    pipeline.to(device)
+
+    save_dir = args.output_dir + "/sample"
+    os.makedirs(save_dir, exist_ok=True)
+
+    rng_state = torch.get_rng_state()
+    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+
+    with torch.no_grad():
+        # with accelerator.autocast():
+        for i, prompt_dict in enumerate(prompts):
+            if not accelerator.is_main_process:
+                continue
+
+            if isinstance(prompt_dict, str):
+                prompt_dict = line_to_prompt_dict(prompt_dict)
+
+            assert isinstance(prompt_dict, dict)
+            negative_prompt = prompt_dict.get("negative_prompt")
+            sample_steps = prompt_dict.get("sample_steps", 30)
+            width = prompt_dict.get("width", 512)
+            height = prompt_dict.get("height", 512)
+            scale = prompt_dict.get("scale", 7.5)
+            seed = prompt_dict.get("seed")
+            controlnet_image = prompt_dict.get("controlnet_image")
+            prompt: str = prompt_dict.get("prompt", "")
+            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+            if seed is not None:
+                torch.manual_seed(seed)
+                torch.cuda.manual_seed(seed)
+
+            scheduler = schedulers.get(sampler_name)
+            if scheduler is None:
+                scheduler = get_my_scheduler(
+                    sample_sampler=sampler_name,
+                    v_parameterization=args.v_parameterization,
+                )
+                schedulers[sampler_name] = scheduler
+            pipeline.scheduler = scheduler
+
+            if prompt_replacement is not None:
+                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+                if negative_prompt is not None:
+                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+
+            if controlnet_image is not None:
+                controlnet_image = Image.open(controlnet_image).convert("RGB")
+                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
+            height = max(64, height - height % 8)  # round down to a multiple of 8 (minimum 64)
+            width = max(64, width - width % 8)  # round down to a multiple of 8 (minimum 64)
+            print(f"prompt: {prompt}")
+            print(f"negative_prompt: {negative_prompt}")
+            print(f"height: {height}")
+            print(f"width: {width}")
+            print(f"sample_steps: {sample_steps}")
+            print(f"scale: {scale}")
+            print(f"sample_sampler: {sampler_name}")
+            if seed is not None:
+                print(f"seed: {seed}")
+            with accelerator.autocast():
+                latents = pipeline(
+                    prompt=prompt,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_steps,
+                    guidance_scale=scale,
+                    negative_prompt=negative_prompt,
+                    controlnet=controlnet,
+                    controlnet_image=controlnet_image,
+                )
+
+            image = pipeline.latents_to_image(latents)[0]
+
+            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+            seed_suffix = "" if seed is None else f"_{seed}"
+            img_filename = (
+                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
+            )
+
+            image.save(os.path.join(save_dir, img_filename))
+
+            # send the image to the tracker only when wandb is enabled
+            try:
+                wandb_tracker = accelerator.get_tracker("wandb")
+                try:
+                    import wandb
+                except ImportError:  # already checked earlier, so this should not raise here
+                    raise ImportError("No wandb / wandb がインストールされていないようです")
+
+                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+            except Exception:  # wandb is disabled
+                pass
+
+    # clear pipeline and cache to reduce vram usage
+    del pipeline
+    torch.cuda.empty_cache()
+
+    torch.set_rng_state(rng_state)
+    if cuda_rng_state is not None:
+        torch.cuda.set_rng_state(cuda_rng_state)
+    vae.to(org_vae_device)
+
+
+# endregion
+
+# region preprocessing
+
+
+class ImageLoadingDataset(torch.utils.data.Dataset):
+    def __init__(self, image_paths):
+        self.images = image_paths
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        img_path = self.images[idx]
+
+        try:
+            image = Image.open(img_path).convert("RGB")
+            # convert to tensor temporarily so dataloader will accept it
+            tensor_pil = transforms.functional.pil_to_tensor(image)
+        except Exception as e:
+            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            return None
+
+        return (tensor_pil, img_path)
+
+
+# endregion
+
+
+# collate_fn用 epoch,stepはmultiprocessing.Value / for collate_fn; epoch and step are multiprocessing.Value
+class collator_class:
+    def __init__(self, epoch, step, dataset):
+        self.current_epoch = epoch
+        self.current_step = step
+        self.dataset = dataset  # not used if worker_info is not None, in case of multiprocessing
+
+    def __call__(self, examples):
+        worker_info = torch.utils.data.get_worker_info()
+        # worker_info is None in the main process
+        if worker_info is not None:
+            dataset = worker_info.dataset
+        else:
+            dataset = self.dataset
+
+        # set epoch and step
+        dataset.set_current_epoch(self.current_epoch.value)
+        dataset.set_current_step(self.current_step.value)
+        return examples[0]
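+
+
+# Illustrative usage sketch (not called anywhere in the library): how collator_class is
+# expected to be wired to a DataLoader. _DemoDataset is hypothetical; real datasets used
+# here provide set_current_epoch / set_current_step themselves.
+def _collator_usage_example():
+    import multiprocessing
+
+    class _DemoDataset(torch.utils.data.Dataset):
+        def __init__(self):
+            self.current_epoch = 0
+            self.current_step = 0
+
+        def set_current_epoch(self, epoch):
+            self.current_epoch = epoch
+
+        def set_current_step(self, step):
+            self.current_step = step
+
+        def __len__(self):
+            return 1
+
+        def __getitem__(self, idx):
+            return {"index": idx}
+
+    dataset = _DemoDataset()
+    current_epoch = multiprocessing.Value("i", 0)
+    current_step = multiprocessing.Value("i", 0)
+    collator = collator_class(current_epoch, current_step, dataset)
+    loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collator)
+    return next(iter(loader))  # {"index": 0}; epoch/step were pushed into the dataset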
+
+
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: List[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    @property
+    def moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
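+
+
+# Illustrative sketch (not called anywhere in the library): LossRecorder keeps one slot per
+# step of the latest epoch, so moving_average is the mean loss over the most recent pass.
+def _loss_recorder_example():
+    recorder = LossRecorder()
+    for step, loss in enumerate([1.0, 2.0, 3.0]):  # epoch 0 appends one slot per step
+        recorder.add(epoch=0, step=step, loss=loss)
+    for step, loss in enumerate([0.5, 0.5, 0.5]):  # later epochs overwrite slot by slot
+        recorder.add(epoch=1, step=step, loss=loss)
+    return recorder.moving_average  # 0.5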
diff --git a/external/llite/library/utils.py b/external/llite/library/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d801a676da99e7815957dbe6d25a878ae8dfcfc
--- /dev/null
+++ b/external/llite/library/utils.py
@@ -0,0 +1,6 @@
+import threading
+from typing import *
+
+
+def fire_in_thread(f, *args, **kwargs):
+    threading.Thread(target=f, args=args, kwargs=kwargs).start()
\ No newline at end of file
diff --git a/external/llite/networks/check_lora_weights.py b/external/llite/networks/check_lora_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..51f581b29802eb6452bffe1d2cf9a26d88552e38
--- /dev/null
+++ b/external/llite/networks/check_lora_weights.py
@@ -0,0 +1,45 @@
+import argparse
+import os
+import torch
+from safetensors.torch import load_file
+
+
+def main(file):
+    print(f"loading: {file}")
+    if os.path.splitext(file)[1] == ".safetensors":
+        sd = load_file(file)
+    else:
+        sd = torch.load(file, map_location="cpu")
+
+    values = []
+
+    keys = list(sd.keys())
+    for key in keys:
+        if "lora_up" in key or "lora_down" in key:
+            values.append((key, sd[key]))
+    print(f"number of LoRA modules: {len(values)}")
+
+    if args.show_all_keys:
+        added = {key for key, _ in values}
+        for key in [k for k in keys if k not in added]:
+            values.append((key, sd[key]))
+        print(f"number of all modules: {len(values)}")
+
+    for key, value in values:
+        value = value.to(torch.float32)
+        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
+    parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+
+    main(args.file)
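+
+    # example invocation (the model path is a placeholder):
+    #   python external/llite/networks/check_lora_weights.py path/to/lora.safetensors --show_all_keys
+    # prints one CSV-style line per key: name, shape, mean(|w|), min(|w|)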
diff --git a/external/llite/networks/control_net_lllite.py b/external/llite/networks/control_net_lllite.py
new file mode 100644
index 0000000000000000000000000000000000000000..85346b490adf80e71fbdfdc258a60f8c2d9bc5b1
--- /dev/null
+++ b/external/llite/networks/control_net_lllite.py
@@ -0,0 +1,446 @@
+import os
+from typing import Optional, List, Type
+import torch
+from external.llite.library import sdxl_original_unet
+
+
+# input_blocksに適用するかどうか / if True, input_blocks are not applied
+SKIP_INPUT_BLOCKS = False
+
+# output_blocksに適用するかどうか / if True, output_blocks are not applied
+SKIP_OUTPUT_BLOCKS = True
+
+# conv2dに適用するかどうか / if True, conv2d are not applied
+SKIP_CONV2D = False
+
+# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
+# if True, only transformer_blocks are applied, and ResBlocks are not applied
+TRANSFORMER_ONLY = True  # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
+
+# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
+ATTN1_2_ONLY = True
+
+# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
+ATTN_QKV_ONLY = True
+
+# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
+# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
+ATTN1_ETC_ONLY = False  # True
+
+# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
+# max index of transformer_blocks. if None, apply to all transformer_blocks
+TRANSFORMER_MAX_BLOCK_INDEX = None
+
+
+class LLLiteModule(torch.nn.Module):
+    def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
+        super().__init__()
+
+        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
+        self.lllite_name = name
+        self.cond_emb_dim = cond_emb_dim
+        self.org_module = [org_module]
+        self.dropout = dropout
+        self.multiplier = multiplier
+
+        if self.is_conv2d:
+            in_dim = org_module.in_channels
+        else:
+            in_dim = org_module.in_features
+
+        # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
+        # conditioning1 embeds conditioning image. it is not called for each timestep
+        modules = []
+        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size
+        if depth == 1:
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+        elif depth == 2:
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
+        elif depth == 3:
+            # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+
+        self.conditioning1 = torch.nn.Sequential(*modules)
+
+        # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
+        # midでconditioning image embeddingと入力を結合する
+        # upで元の次元数に戻す
+        # これらはtimestepごとに呼ばれる
+        # reduce the number of input dimensions with down. inspired by LoRA
+        # combine conditioning image embedding and input with mid
+        # restore to the original dimension with up
+        # these are called for each timestep
+
+        if self.is_conv2d:
+            self.down = torch.nn.Sequential(
+                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.mid = torch.nn.Sequential(
+                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.up = torch.nn.Sequential(
+                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
+            )
+        else:
+            # midの前にconditioningをreshapeすること / reshape conditioning before mid
+            self.down = torch.nn.Sequential(
+                torch.nn.Linear(in_dim, mlp_dim),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.mid = torch.nn.Sequential(
+                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.up = torch.nn.Sequential(
+                torch.nn.Linear(mlp_dim, in_dim),
+            )
+
+        # Zero-Convにする / set to Zero-Conv
+        torch.nn.init.zeros_(self.up[0].weight)  # zero conv
+
+        self.depth = depth  # 1~3
+        self.cond_emb = None
+        self.batch_cond_only = False  # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
+        self.use_zeros_for_batch_uncond = False  # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
+
+        # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
+        # Controlの種類によっては使えるかも
+        # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
+        # it may be available depending on the type of Control
+
+    def set_cond_image(self, cond_image):
+        r"""
+        中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
+        / call the model inside, so if necessary, surround it with torch.no_grad()
+        """
+        if cond_image is None:
+            self.cond_emb = None
+            return
+
+        # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
+        # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
+        cx = self.conditioning1(cond_image)
+        if not self.is_conv2d:
+            # reshape / b,c,h,w -> b,h*w,c
+            n, c, h, w = cx.shape
+            cx = cx.view(n, c, h * w).permute(0, 2, 1)
+        self.cond_emb = cx
+
+    def set_batch_cond_only(self, cond_only, zeros):
+        self.batch_cond_only = cond_only
+        self.use_zeros_for_batch_uncond = zeros
+
+    def apply_to(self):
+        self.org_forward = self.org_module[0].forward
+        self.org_module[0].forward = self.forward
+
+    def forward(self, x):
+        r"""
+        学習用の便利forward。元のモジュールのforwardを呼び出す
+        / convenient forward for training. call the forward of the original module
+        """
+        if self.multiplier == 0.0 or self.cond_emb is None:
+            return self.org_forward(x)
+
+        cx = self.cond_emb
+
+        if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]:  # inference only
+            cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
+            if self.use_zeros_for_batch_uncond:
+                cx[0::2] = 0.0  # uncond is zero
+        # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
+
+        # downで入力の次元数を削減し、conditioning image embeddingと結合する
+        # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
+        # down reduces the number of input dimensions and combines it with conditioning image embedding
+        # we expect that it will mix well by combining in the channel direction instead of adding
+
+        cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
+        cx = self.mid(cx)
+
+        if self.dropout is not None and self.training:
+            cx = torch.nn.functional.dropout(cx, p=self.dropout)
+
+        cx = self.up(cx) * self.multiplier
+
+        # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
+        if self.batch_cond_only:
+            zx = torch.zeros_like(x)
+            zx[1::2] += cx
+            cx = zx
+
+        x = self.org_forward(x + cx)  # ここで元のモジュールを呼び出す / call the original module here
+        return x
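+
+
+def _lllite_module_shape_demo():
+    # Illustrative sketch only (not called by the library): how LLLiteModule hooks a Linear
+    # layer and what shapes it expects. The sizes below (320 features, a 1024px conditioning
+    # image, 128x128 latents) are assumptions matching the first SDXL input blocks.
+    lin = torch.nn.Linear(320, 320)
+    module = LLLiteModule(depth=1, cond_emb_dim=32, name="demo", org_module=lin, mlp_dim=64)
+    module.apply_to()  # lin.forward now routes through module.forward
+
+    with torch.no_grad():
+        module.set_cond_image(torch.randn(1, 3, 1024, 1024))  # -> cond_emb (1, 128*128, 32)
+        x = torch.randn(1, 128 * 128, 320)  # tokens fed to the hooked projection
+        y = lin(x)  # original Linear applied to x plus the control residual from down/mid/up
+    return y.shape  # torch.Size([1, 16384, 320])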
+
+
+class ControlNetLLLite(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+
+    def __init__(
+        self,
+        unet: sdxl_original_unet.SdxlUNet2DConditionModel,
+        cond_emb_dim: int = 16,
+        mlp_dim: int = 16,
+        dropout: Optional[float] = None,
+        varbose: Optional[bool] = False,
+        multiplier: Optional[float] = 1.0,
+    ) -> None:
+        super().__init__()
+        # self.unets = [unet]
+
+        def create_modules(
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+            module_class: Type[object],
+        ) -> List[torch.nn.Module]:
+            prefix = "lllite_unet"
+
+            modules = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        if is_linear or (is_conv2d and not SKIP_CONV2D):
+                            # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
+                            # block index to depth: depth is used to calculate conditioning size and channels
+                            block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
+                            index1 = int(index1)
+                            if block_name == "input_blocks":
+                                if SKIP_INPUT_BLOCKS:
+                                    continue
+                                depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
+                            elif block_name == "middle_block":
+                                depth = 3
+                            elif block_name == "output_blocks":
+                                if SKIP_OUTPUT_BLOCKS:
+                                    continue
+                                depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
+                                if int(index2) >= 2:
+                                    depth -= 1
+                            else:
+                                raise NotImplementedError()
+
+                            lllite_name = prefix + "." + name + "." + child_name
+                            lllite_name = lllite_name.replace(".", "_")
+
+                            if TRANSFORMER_MAX_BLOCK_INDEX is not None:
+                                p = lllite_name.find("transformer_blocks")
+                                if p >= 0:
+                                    tf_index = int(lllite_name[p:].split("_")[2])
+                                    if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
+                                        continue
+
+                            #  time embは適用外とする
+                            # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
+                            # time emb is not applied
+                            # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
+                            if "emb_layers" in lllite_name or (
+                                "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
+                            ):
+                                continue
+
+                            if ATTN1_2_ONLY:
+                                if not ("attn1" in lllite_name or "attn2" in lllite_name):
+                                    continue
+                                if ATTN_QKV_ONLY:
+                                    if "to_out" in lllite_name:
+                                        continue
+
+                            if ATTN1_ETC_ONLY:
+                                if "proj_out" in lllite_name:
+                                    pass
+                                elif "attn1" in lllite_name and (
+                                    "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
+                                ):
+                                    pass
+                                elif "ff_net_2" in lllite_name:
+                                    pass
+                                else:
+                                    continue
+
+                            module = module_class(
+                                depth,
+                                cond_emb_dim,
+                                lllite_name,
+                                child_module,
+                                mlp_dim,
+                                dropout=dropout,
+                                multiplier=multiplier,
+                            )
+                            modules.append(module)
+            print(f"Returning {len(modules)} modules for llite net")
+            return modules
+
+        target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
+        if not TRANSFORMER_ONLY:
+            target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        # create module instances
+        self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
+        print(f"created ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
+
+    def forward(self, x):
+        return x  # dummy
+
+    def set_cond_image(self, cond_image):
+        r"""
+        中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
+        / call the model inside, so if necessary, surround it with torch.no_grad()
+        """
+        for module in self.unet_modules:
+            module.set_cond_image(cond_image)
+
+    def set_batch_cond_only(self, cond_only, zeros):
+        for module in self.unet_modules:
+            module.set_batch_cond_only(cond_only, zeros)
+
+    def set_multiplier(self, multiplier):
+        for module in self.unet_modules:
+            module.multiplier = multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self):
+        print("applying LLLite for U-Net...")
+        for module in self.unet_modules:
+            module.apply_to()
+            self.add_module(module.lllite_name, module)
+
+    # マージできるかどうかを返す / return whether the weights can be merged into the base model
+    def is_mergeable(self):
+        return False
+
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        raise NotImplementedError()
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_optimizer_params(self):
+        self.requires_grad_(True)
+        return self.parameters()
+
+    def prepare_grad_etc(self):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+
+if __name__ == "__main__":
+    # デバッグ用 / for debug
+
+    # sdxl_original_unet.USE_REENTRANT = False
+
+    # test shape etc
+    print("create unet")
+    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+    unet.to("cuda").to(torch.float16)
+
+    print("create ControlNet-LLLite")
+    control_net = ControlNetLLLite(unet, 32, 64)
+    control_net.apply_to()
+    control_net.to("cuda")
+
+    print(control_net)
+
+    # print number of parameters
+    print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
+
+    input()
+
+    unet.set_use_memory_efficient_attention(True, False)
+    unet.set_gradient_checkpointing(True)
+    unet.train()  # for gradient checkpointing
+
+    control_net.train()
+
+    # # visualize
+    # import torchviz
+    # print("run visualize")
+    # controlnet.set_control(conditioning_image)
+    # output = unet(x, t, ctx, y)
+    # print("make_dot")
+    # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
+    # print("render")
+    # image.format = "svg" # "png"
+    # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
+    # input()
+
+    import bitsandbytes
+
+    optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
+
+    scaler = torch.cuda.amp.GradScaler(enabled=True)
+
+    print("start training")
+    steps = 10
+
+    sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
+    for step in range(steps):
+        print(f"step {step}")
+
+        batch_size = 1
+        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
+        x = torch.randn(batch_size, 4, 128, 128).cuda()
+        t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
+        ctx = torch.randn(batch_size, 77, 2048).cuda()
+        y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
+
+        with torch.cuda.amp.autocast(enabled=True):
+            control_net.set_cond_image(conditioning_image)
+
+            output = unet(x, t, ctx, y)
+            target = torch.randn_like(output)
+            loss = torch.nn.functional.mse_loss(output, target)
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad(set_to_none=True)
+        print(sample_param)
+
+    # from safetensors.torch import save_file
+
+    # save_file(control_net.state_dict(), "logs/control_net.safetensors")
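+
+    # Inference-time usage sketch (commented out; the weight path and the batch layout are
+    # assumptions, not part of this script):
+    # control_net.load_weights("logs/control_net.safetensors")
+    # with torch.no_grad():
+    #     control_net.set_cond_image(conditioning_image)  # precompute cond embeddings once
+    #     control_net.set_batch_cond_only(False, False)   # apply to both uncond and cond
+    #     output = unet(x, t, ctx, y)                     # a batch of 2 (uncond, cond) is also handled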
diff --git a/external/llite/networks/control_net_lllite_for_train.py b/external/llite/networks/control_net_lllite_for_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0268800159f36b607b2d0457f561d43fa816ceb6
--- /dev/null
+++ b/external/llite/networks/control_net_lllite_for_train.py
@@ -0,0 +1,502 @@
+# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
+# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
+
+import os
+import re
+from typing import Optional, List, Type
+import torch
+from external.llite.library import sdxl_original_unet
+
+
+# input_blocksに適用するかどうか / if True, input_blocks are not applied
+SKIP_INPUT_BLOCKS = False
+
+# output_blocksに適用するかどうか / if True, output_blocks are not applied
+SKIP_OUTPUT_BLOCKS = True
+
+# conv2dに適用するかどうか / if True, conv2d are not applied
+SKIP_CONV2D = False
+
+# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
+# if True, only transformer_blocks are applied, and ResBlocks are not applied
+TRANSFORMER_ONLY = True  # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
+
+# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
+ATTN1_2_ONLY = True
+
+# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
+ATTN_QKV_ONLY = True
+
+# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
+# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
+ATTN1_ETC_ONLY = False  # True
+
+# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
+# max index of transformer_blocks. if None, apply to all transformer_blocks
+TRANSFORMER_MAX_BLOCK_INDEX = None
+
+ORIGINAL_LINEAR = torch.nn.Linear
+ORIGINAL_CONV2D = torch.nn.Conv2d
+
+
+def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
+    # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
+    # conditioning1 embeds conditioning image. it is not called for each timestep
+    modules = []
+    modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size
+    if depth == 1:
+        modules.append(torch.nn.ReLU(inplace=True))
+        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+    elif depth == 2:
+        modules.append(torch.nn.ReLU(inplace=True))
+        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
+    elif depth == 3:
+        # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
+        modules.append(torch.nn.ReLU(inplace=True))
+        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
+        modules.append(torch.nn.ReLU(inplace=True))
+        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+
+    module.lllite_conditioning1 = torch.nn.Sequential(*modules)
+
+    # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
+    # midでconditioning image embeddingと入力を結合する
+    # upで元の次元数に戻す
+    # これらはtimestepごとに呼ばれる
+    # reduce the number of input dimensions with down. inspired by LoRA
+    # combine conditioning image embedding and input with mid
+    # restore to the original dimension with up
+    # these are called for each timestep
+
+    module.lllite_down = torch.nn.Sequential(
+        ORIGINAL_LINEAR(in_dim, mlp_dim),
+        torch.nn.ReLU(inplace=True),
+    )
+    module.lllite_mid = torch.nn.Sequential(
+        ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
+        torch.nn.ReLU(inplace=True),
+    )
+    module.lllite_up = torch.nn.Sequential(
+        ORIGINAL_LINEAR(mlp_dim, in_dim),
+    )
+
+    # Zero-Convにする / set to Zero-Conv
+    torch.nn.init.zeros_(module.lllite_up[0].weight)  # zero conv
+
+
+class LLLiteLinear(ORIGINAL_LINEAR):
+    def __init__(self, in_features: int, out_features: int, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.enabled = False
+
+    def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
+        self.enabled = True
+        self.lllite_name = name
+        self.cond_emb_dim = cond_emb_dim
+        self.dropout = dropout
+        self.multiplier = multiplier  # ignored
+
+        in_dim = self.in_features
+        add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
+
+        self.cond_image = None
+        self.cond_emb = None
+
+    def set_cond_image(self, cond_image):
+        self.cond_image = cond_image
+        self.cond_emb = None
+
+    def forward(self, x):
+        if not self.enabled:
+            return super().forward(x)
+
+        if self.cond_emb is None:
+            self.cond_emb = self.lllite_conditioning1(self.cond_image)
+        cx = self.cond_emb
+
+        # reshape / b,c,h,w -> b,h*w,c
+        n, c, h, w = cx.shape
+        cx = cx.view(n, c, h * w).permute(0, 2, 1)
+
+        cx = torch.cat([cx, self.lllite_down(x)], dim=2)
+        cx = self.lllite_mid(cx)
+
+        if self.dropout is not None and self.training:
+            cx = torch.nn.functional.dropout(cx, p=self.dropout)
+
+        cx = self.lllite_up(cx) * self.multiplier
+
+        x = super().forward(x + cx)  # ここで元のモジュールを呼び出す / call the original module here
+        return x
+
+
+class LLLiteConv2d(ORIGINAL_CONV2D):
+    def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+        self.enabled = False
+
+    def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
+        self.enabled = True
+        self.lllite_name = name
+        self.cond_emb_dim = cond_emb_dim
+        self.dropout = dropout
+        self.multiplier = multiplier  # ignored
+
+        in_dim = self.in_channels
+        add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
+
+        self.cond_image = None
+        self.cond_emb = None
+
+    def set_cond_image(self, cond_image):
+        self.cond_image = cond_image
+        self.cond_emb = None
+
+    def forward(self, x):  # , cond_image=None):
+        if not self.enabled:
+            return super().forward(x)
+
+        if self.cond_emb is None:
+            self.cond_emb = self.lllite_conditioning1(self.cond_image)
+        cx = self.cond_emb
+
+        # use the lllite_* submodules created in set_lllite (same flow as LLLiteLinear.forward)
+        cx = torch.cat([cx, self.lllite_down(x)], dim=1)
+        cx = self.lllite_mid(cx)
+
+        if self.dropout is not None and self.training:
+            cx = torch.nn.functional.dropout(cx, p=self.dropout)
+
+        cx = self.lllite_up(cx) * self.multiplier
+
+        x = super().forward(x + cx)  # ここで元のモジュールを呼び出す / call the original module here
+        return x
+
+
+class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    LLLITE_PREFIX = "lllite_unet"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def apply_lllite(
+        self,
+        cond_emb_dim: int = 16,
+        mlp_dim: int = 16,
+        dropout: Optional[float] = None,
+        varbose: Optional[bool] = False,
+        multiplier: Optional[float] = 1.0,
+    ) -> None:
+        def apply_to_modules(
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[torch.nn.Module]:
+            prefix = "lllite_unet"
+
+            modules = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "LLLiteLinear"
+                        is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
+
+                        if is_linear or (is_conv2d and not SKIP_CONV2D):
+                            # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
+                            # block index to depth: depth is used to calculate conditioning size and channels
+                            block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
+                            index1 = int(index1)
+                            if block_name == "input_blocks":
+                                if SKIP_INPUT_BLOCKS:
+                                    continue
+                                depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
+                            elif block_name == "middle_block":
+                                depth = 3
+                            elif block_name == "output_blocks":
+                                if SKIP_OUTPUT_BLOCKS:
+                                    continue
+                                depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
+                                if int(index2) >= 2:
+                                    depth -= 1
+                            else:
+                                raise NotImplementedError()
+
+                            lllite_name = prefix + "." + name + "." + child_name
+                            lllite_name = lllite_name.replace(".", "_")
+
+                            if TRANSFORMER_MAX_BLOCK_INDEX is not None:
+                                p = lllite_name.find("transformer_blocks")
+                                if p >= 0:
+                                    tf_index = int(lllite_name[p:].split("_")[2])
+                                    if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
+                                        continue
+
+                            #  time embは適用外とする
+                            # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
+                            # time emb is not applied
+                            # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
+                            if "emb_layers" in lllite_name or (
+                                "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
+                            ):
+                                continue
+
+                            if ATTN1_2_ONLY:
+                                if not ("attn1" in lllite_name or "attn2" in lllite_name):
+                                    continue
+                                if ATTN_QKV_ONLY:
+                                    if "to_out" in lllite_name:
+                                        continue
+
+                            if ATTN1_ETC_ONLY:
+                                if "proj_out" in lllite_name:
+                                    pass
+                                elif "attn1" in lllite_name and (
+                                    "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
+                                ):
+                                    pass
+                                elif "ff_net_2" in lllite_name:
+                                    pass
+                                else:
+                                    continue
+
+                            child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
+                            modules.append(child_module)
+
+            return modules
+
+        target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
+        if not TRANSFORMER_ONLY:
+            target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        # create module instances
+        self.lllite_modules = apply_to_modules(self, target_modules)
+        print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
+
+    # def prepare_optimizer_params(self):
+    def prepare_params(self):
+        train_params = []
+        non_train_params = []
+        for name, p in self.named_parameters():
+            if "lllite" in name:
+                train_params.append(p)
+            else:
+                non_train_params.append(p)
+        print(f"count of trainable parameters: {len(train_params)}")
+        print(f"count of non-trainable parameters: {len(non_train_params)}")
+
+        for p in non_train_params:
+            p.requires_grad_(False)
+
+        # without this, an error occurs in the optimizer
+        #       RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
+        non_train_params[0].requires_grad_(True)
+
+        for p in train_params:
+            p.requires_grad_(True)
+
+        return train_params
+
+    # def prepare_grad_etc(self):
+    #     self.requires_grad_(True)
+
+    # def on_epoch_start(self):
+    #     self.train()
+
+    def get_trainable_params(self):
+        return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
+
+    def save_lllite_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        org_state_dict = self.state_dict()
+
+        # copy LLLite keys from org_state_dict to state_dict with key conversion
+        state_dict = {}
+        for key in org_state_dict.keys():
+            # split with ".lllite"
+            pos = key.find(".lllite")
+            if pos < 0:
+                continue
+            lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
+            lllite_key = lllite_key.replace(".", "_") + key[pos:]
+            lllite_key = lllite_key.replace(".lllite_", ".")
+            state_dict[lllite_key] = org_state_dict[key]
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    def load_lllite_weights(self, file, non_lllite_unet_sd=None):
+        r"""
+        LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
+        この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
+
+        If you do not want to load LLLite weights (use initialized values), specify None for file.
+        In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
+        """
+        if not file:
+            state_dict = self.state_dict()
+            for key in non_lllite_unet_sd:
+                if key in state_dict:
+                    state_dict[key] = non_lllite_unet_sd[key]
+            info = self.load_state_dict(state_dict, False)
+            return info
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        # module_name = module_name.replace("_block", "@blocks")
+        # module_name = module_name.replace("_layer", "@layer")
+        # module_name = module_name.replace("to_", "to@")
+        # module_name = module_name.replace("time_embed", "time@embed")
+        # module_name = module_name.replace("label_emb", "label@emb")
+        # module_name = module_name.replace("skip_connection", "skip@connection")
+        # module_name = module_name.replace("proj_in", "proj@in")
+        # module_name = module_name.replace("proj_out", "proj@out")
+        pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
+
+        # convert to lllite with U-Net state dict
+        state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
+        for key in weights_sd.keys():
+            # split with "."
+            pos = key.find(".")
+            if pos < 0:
+                continue
+
+            module_name = key[:pos]
+            weight_name = key[pos + 1 :]  # exclude "."
+            module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
+
+            # これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
+            # module_name = module_name.replace("_", ".")
+
+            # ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
+            matches = pattern.findall(module_name)
+            for m in matches:  # findall returns a (possibly empty) list
+                # print(module_name, m)
+                module_name = module_name.replace(m, m.replace("_", "@"))
+            module_name = module_name.replace("_", ".")
+            module_name = module_name.replace("@", "_")
+
+            lllite_key = module_name + ".lllite_" + weight_name
+
+            state_dict[lllite_key] = weights_sd[key]
+
+        info = self.load_state_dict(state_dict, False)
+        return info
+
+    def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
+        for m in self.lllite_modules:
+            m.set_cond_image(cond_image)
+        return super().forward(x, timesteps, context, y, **kwargs)
+
+
+def replace_unet_linear_and_conv2d():
+    print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
+    sdxl_original_unet.torch.nn.Linear = LLLiteLinear
+    sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
+
+
+if __name__ == "__main__":
+    # デバッグ用 / for debug
+
+    # sdxl_original_unet.USE_REENTRANT = False
+    replace_unet_linear_and_conv2d()
+
+    # test shape etc
+    print("create unet")
+    unet = SdxlUNet2DConditionModelControlNetLLLite()
+
+    print("enable ControlNet-LLLite")
+    unet.apply_lllite(32, 64, None, False, 1.0)
+    unet.to("cuda")  # .to(torch.float16)
+
+    # from safetensors.torch import load_file
+
+    # model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
+    # unet_sd = {}
+
+    # # copy U-Net keys from unet_state_dict to state_dict
+    # prefix = "model.diffusion_model."
+    # for key in model_sd.keys():
+    #     if key.startswith(prefix):
+    #         converted_key = key[len(prefix) :]
+    #         unet_sd[converted_key] = model_sd[key]
+
+    # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
+    # print(info)
+
+    # print(unet)
+
+    # print number of parameters
+    params = unet.prepare_params()
+    print("number of parameters", sum(p.numel() for p in params))
+    # print("type any key to continue")
+    # input()
+
+    unet.set_use_memory_efficient_attention(True, False)
+    unet.set_gradient_checkpointing(True)
+    unet.train()  # for gradient checkpointing
+
+    # # visualize
+    # import torchviz
+    # print("run visualize")
+    # controlnet.set_control(conditioning_image)
+    # output = unet(x, t, ctx, y)
+    # print("make_dot")
+    # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
+    # print("render")
+    # image.format = "svg" # "png"
+    # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
+    # input()
+
+    import bitsandbytes
+
+    optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
+
+    scaler = torch.cuda.amp.GradScaler(enabled=True)
+
+    print("start training")
+    steps = 10
+    batch_size = 1
+
+    sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
+    for step in range(steps):
+        print(f"step {step}")
+
+        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
+        x = torch.randn(batch_size, 4, 128, 128).cuda()
+        t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
+        ctx = torch.randn(batch_size, 77, 2048).cuda()
+        y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
+
+        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
+            output = unet(x, t, ctx, y, conditioning_image)
+            target = torch.randn_like(output)
+            loss = torch.nn.functional.mse_loss(output, target)
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad(set_to_none=True)
+        print(sample_param)
+
+    # from safetensors.torch import save_file
+
+    # print("save weights")
+    # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
diff --git a/external/llite/networks/dylora.py b/external/llite/networks/dylora.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a55d1988ac4bc35f1d1657792efcfbaaccea9e
--- /dev/null
+++ b/external/llite/networks/dylora.py
@@ -0,0 +1,450 @@
+# some codes are copied from:
+# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
+
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+# Changes made to the original code:
+# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
+#  ------------------------------------------------------------------------------------------
+#  Copyright (c) Microsoft Corporation. All rights reserved.
+#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+#  ------------------------------------------------------------------------------------------
+
+import math
+import os
+import random
+from typing import List, Tuple, Union
+import torch
+from torch import nn
+
+
+class DyLoRAModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear (or Conv2d) instead of replacing the module itself.
+    """
+
+    # NOTE: support dropout in future
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
+        super().__init__()
+        self.lora_name = lora_name
+        self.lora_dim = lora_dim
+        self.unit = unit
+        assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
+        self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
+
+        if self.is_conv2d and self.is_conv2d_3x3:
+            kernel_size = org_module.kernel_size
+            self.stride = org_module.stride
+            self.padding = org_module.padding
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
+        else:
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
+
+        # same as microsoft's
+        for lora in self.lora_A:
+            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
+        for lora in self.lora_B:
+            torch.nn.init.zeros_(lora)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x):
+        result = self.org_forward(x)
+
+        # specify the dynamic rank
+        trainable_rank = random.randint(0, self.lora_dim - 1)
+        trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit
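+        # e.g. with lora_dim=16 and unit=4, trainable_rank is one of {0, 4, 8, 12};
+        # only ranks [trainable_rank, trainable_rank + unit) are updated in this step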
+
+        # freeze part of the parameters and train only the remaining ones
+        for i in range(0, trainable_rank):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+        for i in range(trainable_rank, trainable_rank + self.unit):
+            self.lora_A[i].requires_grad = True
+            self.lora_B[i].requires_grad = True
+        for i in range(trainable_rank + self.unit, self.lora_dim):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+
+        lora_A = torch.cat(tuple(self.lora_A), dim=0)
+        lora_B = torch.cat(tuple(self.lora_B), dim=1)
+
+        # calculate with lora_A and lora_B
+        if self.is_conv2d_3x3:
+            ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
+            ab = torch.nn.functional.conv2d(ab, lora_B)
+        else:
+            ab = x
+            if self.is_conv2d:
+                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+
+            ab = torch.nn.functional.linear(ab, lora_A)
+            ab = torch.nn.functional.linear(ab, lora_B)
+
+            if self.is_conv2d:
+                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)
+
+        # the last factor scales lower ranks up (probably)
+        result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
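+        # e.g. (illustrative) lora_dim=16, unit=4, trainable_rank=4 -> sqrt(16 / 8) ~ 1.41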
+
+        # NOTE: it might be faster to add this to the weight first and then call linear/conv2d
+        return result
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        # make the state dict identical to a normal LoRA's:
+        # nn.ParameterList entries get names like `.lora_A.0`, so concatenate them as in forward and swap the keys
+        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
+        if self.is_conv2d and not self.is_conv2d_3x3:
+            lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
+
+        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
+        if self.is_conv2d and not self.is_conv2d_3x3:
+            lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
+
+        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
+        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
+
+        i = 0
+        while True:
+            key_a = f"{self.lora_name}.lora_A.{i}"
+            key_b = f"{self.lora_name}.lora_B.{i}"
+            if key_a in sd:
+                sd.pop(key_a)
+                sd.pop(key_b)
+            else:
+                break
+            i += 1
+        return sd
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        # allow loading the same state dict as a normal LoRA (this approach was suggested by ChatGPT)
+        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
+        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
+
+        if lora_A_weight is None or lora_B_weight is None:
+            if strict:
+                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
+            else:
+                return
+
+        if self.is_conv2d and not self.is_conv2d_3x3:
+            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
+            lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
+
+        state_dict.update(
+            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
+        )
+        state_dict.update(
+            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
+        )
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    # extract dim/alpha for conv2d, and block dim
+    conv_dim = kwargs.get("conv_dim", None)
+    conv_alpha = kwargs.get("conv_alpha", None)
+    unit = kwargs.get("unit", None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        assert conv_dim == network_dim, "conv_dim must be same as network_dim"
+        if conv_alpha is None:
+            conv_alpha = 1.0
+        else:
+            conv_alpha = float(conv_alpha)
+    if unit is not None:
+        unit = int(unit)
+    else:
+        unit = 1
+
+    network = DyLoRANetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        lora_dim=network_dim,
+        alpha=network_alpha,
+        apply_to_conv=conv_dim is not None,
+        unit=unit,
+        varbose=True,
+    )
+    return network
+
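+# A minimal usage sketch (illustrative; assumes vae/text_encoder/unet are already-loaded models
+# compatible with the target replace modules below):
+#   network = create_network(1.0, 16, 16.0, vae, text_encoder, unet, unit=4)
+#   network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)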
+
+# Create a network from weights for inference; the weights are not loaded here (they may be merged into the model instead)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file, safe_open
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lora_down" in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha[key] = modules_dim[key]
+
+    module_class = DyLoRAModule
+
+    network = DyLoRANetwork(
+        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+    )
+    return network, weights_sd
+
+
+class DyLoRANetwork(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    def __init__(
+        self,
+        text_encoder,
+        unet,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        apply_to_conv=False,
+        modules_dim=None,
+        modules_alpha=None,
+        unit=1,
+        module_class=DyLoRAModule,
+        varbose=False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.apply_to_conv = apply_to_conv
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
+            if self.apply_to_conv:
+                print(f"apply LoRA to Conv2d with kernel size (3,3).")
+
+        # create module instances
+        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
+            prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
+            loras = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+                            if modules_dim is not None:
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            else:
+                                if is_linear or is_conv2d_1x1 or apply_to_conv:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+
+                            if dim is None or dim == 0:
+                                continue
+
+                            # dropout and fan_in_fan_out are left at their defaults
+                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
+                            loras.append(lora)
+            return loras
+
+        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
+        if modules_dim is not None or self.apply_to_conv:
+            target_modules = target_modules + DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3  # avoid mutating the class attribute in place
+
+        self.unet_loras = create_modules(True, unet, target_modules)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    """
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        apply_text_encoder = apply_unet = False
+        for key in weights_sd.keys():
+            if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                apply_text_encoder = True
+            elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
+                apply_unet = True
+
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(lora.lora_name):
+                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+            lora.merge_to(sd_for_lora, dtype, device)
+
+        print(f"weights are merged")
+    """
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(loras):
+            params = []
+            for lora in loras:
+                params.extend(lora.parameters())
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
+
+        if self.unet_loras:
+            param_data = {"params": enumerate_params(self.unet_loras)}
+            if unet_lr is not None:
+                param_data["lr"] = unet_lr
+            all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    # mask is a tensor with values from 0 to 1
+    def set_region(self, sub_prompt_index, is_last_network, mask):
+        pass
+
+    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
+        pass
diff --git a/external/llite/networks/extract_lora_from_dylora.py b/external/llite/networks/extract_lora_from_dylora.py
new file mode 100644
index 0000000000000000000000000000000000000000..0abee98368911575c6e84974ab7ca7a4cdc97e0e
--- /dev/null
+++ b/external/llite/networks/extract_lora_from_dylora.py
@@ -0,0 +1,125 @@
+# Split a DyLoRA model into lower-rank LoRA approximations (one model per multiple of the unit size)
+# This code is based on the extract_lora_from_models.py file, which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
+# Thanks to cloneofsimo
+
+import argparse
+import math
+import os
+import torch
+from safetensors.torch import load_file, save_file, safe_open
+from tqdm import tqdm
+from library import train_util, model_util
+import numpy as np
+
+
+def load_state_dict(file_name):
+    if model_util.is_safetensors(file_name):
+        sd = load_file(file_name)
+        with safe_open(file_name, framework="pt") as f:
+            metadata = f.metadata()
+    else:
+        sd = torch.load(file_name, map_location="cpu")
+        metadata = None
+
+    return sd, metadata
+
+
+def save_to_file(file_name, model, metadata):
+    if model_util.is_safetensors(file_name):
+        save_file(model, file_name, metadata)
+    else:
+        torch.save(model, file_name)
+
+
+def split_lora_model(lora_sd, unit):
+    max_rank = 0
+
+    # Extract loaded lora dim and alpha
+    for key, value in lora_sd.items():
+        if "lora_down" in key:
+            rank = value.size()[0]
+            if rank > max_rank:
+                max_rank = rank
+    print(f"Max rank: {max_rank}")
+
+    rank = unit
+    split_models = []
+    new_alpha = None
+    while rank < max_rank:
+        print(f"Splitting rank {rank}")
+        new_sd = {}
+        for key, value in lora_sd.items():
+            if "lora_down" in key:
+                new_sd[key] = value[:rank].contiguous()
+            elif "lora_up" in key:
+                new_sd[key] = value[:, :rank].contiguous()
+            else:
+                # for some reason, scaling here breaks the result...
+                # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
+                # scale = math.sqrt(this_rank / rank)  # rank is > unit
+                # print(key, value.size(), this_rank, rank, value, scale)
+                # new_alpha = value * scale  # always same
+                # new_sd[key] = new_alpha
+                new_sd[key] = value
+
+        split_models.append((new_sd, rank, new_alpha))
+        rank += unit
+
+    return max_rank, split_models
+
+
+def split(args):
+    print("loading Model...")
+    lora_sd, metadata = load_state_dict(args.model)
+
+    print("Splitting Model...")
+    original_rank, split_models = split_lora_model(lora_sd, args.unit)
+
+    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
+    for state_dict, new_rank, new_alpha in split_models:
+        # update metadata
+        if metadata is None:
+            new_metadata = {}
+        else:
+            new_metadata = metadata.copy()
+
+        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
+        new_metadata["ss_network_dim"] = str(new_rank)
+        # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
+
+        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
+        new_metadata["sshs_model_hash"] = model_hash
+        new_metadata["sshs_legacy_hash"] = legacy_hash
+
+        filename, ext = os.path.splitext(args.save_to)
+        model_file_name = filename + f"-{new_rank:04d}{ext}"
+
+        print(f"saving model to: {model_file_name}")
+        save_to_file(model_file_name, state_dict, new_metadata)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
+    parser.add_argument(
+        "--save_to",
+        type=str,
+        default=None,
+        help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+        help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    split(args)
diff --git a/external/llite/networks/extract_lora_from_models.py b/external/llite/networks/extract_lora_from_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..6357df55d0a3fa72d32f69fe3bd6246a6315292d
--- /dev/null
+++ b/external/llite/networks/extract_lora_from_models.py
@@ -0,0 +1,296 @@
+# extract approximating LoRA by svd from two SD models
+# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
+# Thanks to cloneofsimo!
+
+import argparse
+import json
+import os
+import time
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from library import sai_model_spec, model_util, sdxl_model_util
+import lora
+
+
+# CLAMP_QUANTILE = 0.99
+# MIN_DIFF = 1e-1
+
+
+def save_to_file(file_name, model, state_dict, dtype):
+    if dtype is not None:
+        for key in list(state_dict.keys()):
+            if type(state_dict[key]) == torch.Tensor:
+                state_dict[key] = state_dict[key].to(dtype)
+
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        save_file(state_dict, file_name)
+    else:
+        torch.save(state_dict, file_name)
+
+
+def svd(
+    model_org=None,
+    model_tuned=None,
+    save_to=None,
+    dim=4,
+    v2=None,
+    sdxl=None,
+    conv_dim=None,
+    v_parameterization=None,
+    device=None,
+    save_precision=None,
+    clamp_quantile=0.99,
+    min_diff=0.01,
+    no_metadata=False,
+):
+    def str_to_dtype(p):
+        if p == "float":
+            return torch.float
+        if p == "fp16":
+            return torch.float16
+        if p == "bf16":
+            return torch.bfloat16
+        return None
+
+    assert not (v2 and sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
+    if v_parameterization is None:
+        v_parameterization = v2
+
+    save_dtype = str_to_dtype(save_precision)
+
+    # load models
+    if not sdxl:
+        print(f"loading original SD model : {model_org}")
+        text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
+        text_encoders_o = [text_encoder_o]
+        print(f"loading tuned SD model : {model_tuned}")
+        text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
+        text_encoders_t = [text_encoder_t]
+        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
+    else:
+        print(f"loading original SDXL model : {model_org}")
+        text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
+        )
+        text_encoders_o = [text_encoder_o1, text_encoder_o2]
+        print(f"loading original SDXL model : {model_tuned}")
+        text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
+        )
+        text_encoders_t = [text_encoder_t1, text_encoder_t2]
+        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
+
+    # create LoRA network to extract weights: Use dim (rank) as alpha
+    if conv_dim is None:
+        kwargs = {}
+    else:
+        kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
+
+    lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
+    lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
+    assert len(lora_network_o.text_encoder_loras) == len(
+        lora_network_t.text_encoder_loras
+    ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
+
+    # get diffs
+    diffs = {}
+    text_encoder_different = False
+    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
+        lora_name = lora_o.lora_name
+        module_o = lora_o.org_module
+        module_t = lora_t.org_module
+        diff = module_t.weight - module_o.weight
+
+        # Text Encoder might be same
+        if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
+            text_encoder_different = True
+            print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
+
+        diff = diff.float()
+        diffs[lora_name] = diff
+
+    if not text_encoder_different:
+        print("Text encoder is same. Extract U-Net only.")
+        lora_network_o.text_encoder_loras = []
+        diffs = {}
+
+    for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
+        lora_name = lora_o.lora_name
+        module_o = lora_o.org_module
+        module_t = lora_t.org_module
+        diff = module_t.weight - module_o.weight
+        diff = diff.float()
+
+        if device:
+            diff = diff.to(device)
+
+        diffs[lora_name] = diff
+
+    # make LoRA with svd
+    print("calculating by svd")
+    lora_weights = {}
+    with torch.no_grad():
+        for lora_name, mat in tqdm(list(diffs.items())):
+            # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
+            conv2d = len(mat.size()) == 4
+            kernel_size = None if not conv2d else mat.size()[2:4]
+            conv2d_3x3 = conv2d and kernel_size != (1, 1)
+
+            rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
+            out_dim, in_dim = mat.size()[0:2]
+
+            if device:
+                mat = mat.to(device)
+
+            # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
+            rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
+
+            if conv2d:
+                if conv2d_3x3:
+                    mat = mat.flatten(start_dim=1)
+                else:
+                    mat = mat.squeeze()
+
+            U, S, Vh = torch.linalg.svd(mat)
+
+            U = U[:, :rank]
+            S = S[:rank]
+            U = U @ torch.diag(S)
+
+            Vh = Vh[:rank, :]
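+            # low-rank approximation: mat ~ U @ Vh with the singular values folded into U;
+            # U becomes lora_up and Vh becomes lora_down (see the state dict assembly below)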
+
+            dist = torch.cat([U.flatten(), Vh.flatten()])
+            hi_val = torch.quantile(dist, clamp_quantile)
+            low_val = -hi_val
+
+            U = U.clamp(low_val, hi_val)
+            Vh = Vh.clamp(low_val, hi_val)
+
+            if conv2d:
+                U = U.reshape(out_dim, rank, 1, 1)
+                Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
+
+            U = U.to("cpu").contiguous()
+            Vh = Vh.to("cpu").contiguous()
+
+            lora_weights[lora_name] = (U, Vh)
+
+    # make state dict for LoRA
+    lora_sd = {}
+    for lora_name, (up_weight, down_weight) in lora_weights.items():
+        lora_sd[lora_name + ".lora_up.weight"] = up_weight
+        lora_sd[lora_name + ".lora_down.weight"] = down_weight
+        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
+
+    # load state dict to LoRA and save it
+    lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
+    lora_network_save.apply_to(text_encoders_o, unet_o)  # create internal module references for state_dict
+
+    info = lora_network_save.load_state_dict(lora_sd)
+    print(f"Loading extracted LoRA weights: {info}")
+
+    dir_name = os.path.dirname(save_to)
+    if dir_name and not os.path.exists(dir_name):
+        os.makedirs(dir_name, exist_ok=True)
+
+    # minimum metadata
+    net_kwargs = {}
+    if conv_dim is not None:
+        net_kwargs["conv_dim"] = str(conv_dim)
+        net_kwargs["conv_alpha"] = str(float(conv_dim))
+
+    metadata = {
+        "ss_v2": str(v2),
+        "ss_base_model_version": model_version,
+        "ss_network_module": "networks.lora",
+        "ss_network_dim": str(dim),
+        "ss_network_alpha": str(float(dim)),
+        "ss_network_args": json.dumps(net_kwargs),
+    }
+
+    if not no_metadata:
+        title = os.path.splitext(os.path.basename(save_to))[0]
+        sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
+        metadata.update(sai_metadata)
+
+    lora_network_save.save_weights(save_to, save_dtype, metadata)
+    print(f"LoRA weights are saved to: {save_to}")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
+    parser.add_argument(
+        "--v_parameterization",
+        action="store_true",
+        default=None,
+        help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
+    )
+    parser.add_argument(
+        "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
+    )
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
+    )
+    parser.add_argument(
+        "--model_org",
+        type=str,
+        default=None,
+        required=True,
+        help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
+    )
+    parser.add_argument(
+        "--model_tuned",
+        type=str,
+        default=None,
+        required=True,
+        help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
+    )
+    parser.add_argument(
+        "--save_to",
+        type=str,
+        default=None,
+        required=True,
+        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
+    )
+    parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
+    parser.add_argument(
+        "--conv_dim",
+        type=int,
+        default=None,
+        help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
+    )
+    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--clamp_quantile",
+        type=float,
+        default=0.99,
+        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
+    )
+    parser.add_argument(
+        "--min_diff",
+        type=float,
+        default=0.01,
+        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+        + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
+    )
+    parser.add_argument(
+        "--no_metadata",
+        action="store_true",
+        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    svd(**vars(args))
diff --git a/external/llite/networks/lora.py b/external/llite/networks/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c75cd428f39481d831f8d4e316dd7f45f61ce32
--- /dev/null
+++ b/external/llite/networks/lora.py
@@ -0,0 +1,1225 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import re
+
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear (or Conv2d) instead of replacing the module itself.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        # if limit_rank:
+        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
+        #   if self.lora_dim != lora_dim:
+        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        # else:
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == "Conv2d":
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
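+        # e.g. with the defaults lora_dim=4 and alpha=1, scale = 1 / 4 = 0.25
+        # (alpha == 0 or None falls back to alpha = lora_dim, i.e. scale = 1.0)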
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x):
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            # this could also be derived from the mask, but rank_dropout is used here expecting an augmentation-like effect
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
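+            # e.g. rank_dropout=0.25 scales the surviving ranks by 1 / 0.75 ~ 1.33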
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * scale
+
+
+class LoRAInfModule(LoRAModule):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+        self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
+        self.enabled = True
+
+        # check regional or not by lora_name
+        self.text_encoder = False
+        if lora_name.startswith("lora_te_"):
+            self.regional = False
+            self.use_sub_prompt = True
+            self.text_encoder = True
+        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = True
+        elif "time_emb" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = False
+        else:
+            self.regional = True
+            self.use_sub_prompt = False
+
+        self.network: LoRANetwork = None
+
+    def set_network(self, network):
+        self.network = network
+
+    # freeze the LoRA weights and merge them into the original module
+    def merge_to(self, sd, dtype, device):
+        # get up/down weight
+        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+
+        # extract weight from org_module
+        org_sd = self.org_module.state_dict()
+        weight = org_sd["weight"].to(torch.float)
+
+        # merge weight
+        if len(weight.size()) == 2:
+            # linear
+            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                weight
+                + self.multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
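+            # shapes: down_weight (rank, in, kh, kw), up_weight (out, rank, 1, 1);
+            # conv2d over the permuted down weight yields (in, out, kh, kw), permuted back to (out, in, kh, kw)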
+            # print(conved.size(), weight.size(), module.stride, module.padding)
+            weight = weight + self.multiplier * conved * self.scale
+
+        # set weight to org_module
+        org_sd["weight"] = weight.to(dtype)
+        self.org_module.load_state_dict(org_sd)
+
+    # return this module's weight delta so that a merge can be reverted later
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # get up/down weight from module
+        up_weight = self.lora_up.weight.to(torch.float)
+        down_weight = self.lora_down.weight.to(torch.float)
+
+        # pre-calculated weight
+        if len(down_weight.size()) == 2:
+            # linear
+            weight = multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            weight = multiplier * conved * self.scale
+
+        return weight
+
+    def set_region(self, region):
+        self.region = region
+        self.region_mask = None
+
+    def default_forward(self, x):
+        # print("default_forward", self.lora_name, x.size())
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+    def forward(self, x):
+        if not self.enabled:
+            return self.org_forward(x)
+
+        if self.network is None or self.network.sub_prompt_index is None:
+            return self.default_forward(x)
+        if not self.regional and not self.use_sub_prompt:
+            return self.default_forward(x)
+
+        if self.regional:
+            return self.regional_forward(x)
+        else:
+            return self.sub_prompt_forward(x)
+
+    def get_mask_for_x(self, x):
+        # calculate size from shape of x
+        if len(x.size()) == 4:
+            h, w = x.size()[2:4]
+            area = h * w
+        else:
+            area = x.size()[1]
+
+        mask = self.network.mask_dic.get(area, None)
+        if mask is None:
+            # raise ValueError(f"mask is None for resolution {area}")
+            # emb_layers in SDXL doesn't have mask
+            # print(f"mask is None for resolution {area}, {x.size()}")
+            mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
+            return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
+        if len(x.size()) != 4:
+            mask = torch.reshape(mask, (1, -1, 1))
+        return mask
+
+    def regional_forward(self, x):
+        if "attn2_to_out" in self.lora_name:
+            return self.to_out_forward(x)
+
+        if self.network.mask_dic is None:  # sub_prompt_index >= 3
+            return self.default_forward(x)
+
+        # apply mask for LoRA result
+        lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        mask = self.get_mask_for_x(lx)
+        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
+        lx = lx * mask
+
+        x = self.org_forward(x)
+        x = x + lx
+
+        if "attn2_to_q" in self.lora_name and self.network.is_last_network:
+            x = self.postp_to_q(x)
+
+        return x
+
+    def postp_to_q(self, x):
+        # repeat x to num_sub_prompts
+        has_real_uncond = x.size()[0] // self.network.batch_size == 3
+        qc = self.network.batch_size  # uncond
+        qc += self.network.batch_size * self.network.num_sub_prompts  # cond
+        if has_real_uncond:
+            qc += self.network.batch_size  # real_uncond
+
+        query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
+        query[: self.network.batch_size] = x[: self.network.batch_size]
+
+        for i in range(self.network.batch_size):
+            qi = self.network.batch_size + i * self.network.num_sub_prompts
+            query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
+
+        if has_real_uncond:
+            query[-self.network.batch_size :] = x[-self.network.batch_size :]
+
+        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
+        return query
+
+    def sub_prompt_forward(self, x):
+        if x.size()[0] == self.network.batch_size:  # if uncond in text_encoder, do not apply LoRA
+            return self.org_forward(x)
+
+        emb_idx = self.network.sub_prompt_index
+        if not self.text_encoder:
+            emb_idx += self.network.batch_size
+
+        # apply sub prompt of X
+        lx = x[emb_idx :: self.network.num_sub_prompts]
+        lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
+
+        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
+
+        x = self.org_forward(x)
+        x[emb_idx :: self.network.num_sub_prompts] += lx
+
+        return x
+
+    def to_out_forward(self, x):
+        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
+
+        if self.network.is_last_network:
+            masks = [None] * self.network.num_sub_prompts
+            self.network.shared[self.lora_name] = (None, masks)
+        else:
+            lx, masks = self.network.shared[self.lora_name]
+
+        # call own LoRA
+        x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
+        lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
+
+        if self.network.is_last_network:
+            lx = torch.zeros(
+                (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
+            )
+            self.network.shared[self.lora_name] = (lx, masks)
+
+        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
+        lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
+        masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
+
+        # if not last network, return x and masks
+        x = self.org_forward(x)
+        if not self.network.is_last_network:
+            return x
+
+        lx, masks = self.network.shared.pop(self.lora_name)
+
+        # if last network, combine separated x with mask weighted sum
+        has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
+
+        out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
+        out[: self.network.batch_size] = x[: self.network.batch_size]  # uncond
+        if has_real_uncond:
+            out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond
+
+        # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # if num_sub_prompts > num of LoRAs, fill with zero
+        for i in range(len(masks)):
+            if masks[i] is None:
+                masks[i] = torch.zeros_like(masks[0])
+
+        mask = torch.cat(masks)
+        mask_sum = torch.sum(mask, dim=0) + 1e-4
+        for i in range(self.network.batch_size):
+            # process one image at a time
+            lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
+            lx1 = lx1 * mask
+            lx1 = torch.sum(lx1, dim=0)
+
+            xi = self.network.batch_size + i * self.network.num_sub_prompts
+            x1 = x[xi : xi + self.network.num_sub_prompts]
+            x1 = x1 * mask
+            x1 = torch.sum(x1, dim=0)
+            x1 = x1 / mask_sum
+
+            x1 = x1 + lx1
+            out[self.network.batch_size + i] = x1
+
+        # print("to_out_forward", x.size(), out.size(), has_real_uncond)
+        return out
+
+
+def parse_block_lr_kwargs(nw_kwargs):
+    down_lr_weight = nw_kwargs.get("down_lr_weight", None)
+    mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
+    up_lr_weight = nw_kwargs.get("up_lr_weight", None)
+
+    # if none of these are specified, block-wise LR is disabled and None is returned
+    if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
+        return None, None, None
+
+    # extract learning rate weight for each block
+    if down_lr_weight is not None:
+        # if some parameters are not set, use zero
+        if "," in down_lr_weight:
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+
+    if up_lr_weight is not None:
+        if "," in up_lr_weight:
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+    down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
+        down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
+    )
+
+    return down_lr_weight, mid_lr_weight, up_lr_weight
+
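+# Illustrative kwargs for block-wise learning rates (the exact number of per-block entries
+# must match what get_block_lr_weight expects, e.g. one value per U-Net down/up block):
+#   {"down_lr_weight": "0,0,0,0,0,0,1,1,1,1,1,1", "mid_lr_weight": "1", "up_lr_weight": "1,1,1,1,1,1,1,1,1,1,1,1"}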
+
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    neuron_dropout: Optional[float] = None,
+    **kwargs,
+):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    # extract dim/alpha for conv2d, and block dim
+    conv_dim = kwargs.get("conv_dim", None)
+    conv_alpha = kwargs.get("conv_alpha", None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = 1.0
+        else:
+            conv_alpha = float(conv_alpha)
+
+    # block dim/alpha/lr
+    block_dims = kwargs.get("block_dims", None)
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+
+    # if any of the above is specified, enable per-block dim (rank)
+    if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
+        block_alphas = kwargs.get("block_alphas", None)
+        conv_block_dims = kwargs.get("conv_block_dims", None)
+        conv_block_alphas = kwargs.get("conv_block_alphas", None)
+
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
+            block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+        )
+
+        # remove block dim/alpha without learning rate
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
+            block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+        )
+
+    else:
+        block_alphas = None
+        conv_block_dims = None
+        conv_block_alphas = None
+
+    # rank/module dropout
+    rank_dropout = kwargs.get("rank_dropout", None)
+    if rank_dropout is not None:
+        rank_dropout = float(rank_dropout)
+    module_dropout = kwargs.get("module_dropout", None)
+    if module_dropout is not None:
+        module_dropout = float(module_dropout)
+
+    # there are a lot of arguments here...
+    network = LoRANetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        lora_dim=network_dim,
+        alpha=network_alpha,
+        dropout=neuron_dropout,
+        rank_dropout=rank_dropout,
+        module_dropout=module_dropout,
+        conv_lora_dim=conv_dim,
+        conv_alpha=conv_alpha,
+        block_dims=block_dims,
+        block_alphas=block_alphas,
+        conv_block_dims=conv_block_dims,
+        conv_block_alphas=conv_block_alphas,
+        varbose=True,
+    )
+
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+
+    return network
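+
+# Illustrative sketch (added comment, not part of the original code): extra settings arrive as string-valued
+# kwargs, so a call equivalent to conv_dim=4, conv_alpha=1.0 and a 10% neuron dropout would look like:
+#
+#   network = create_network(
+#       1.0, 8, 4.0, vae, text_encoder, unet,
+#       neuron_dropout=0.1, conv_dim="4", conv_alpha="1.0",
+#   )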
+
+
+# this function may be called from outside this module
+# network_dim and network_alpha already contain their default values
+# block_dims and block_alphas are either both None or both set
+# conv_dim and conv_alpha are either both None or both set
+def get_block_dims_and_alphas(
+    block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+):
+    num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
+
+    def parse_ints(s):
+        return [int(i) for i in s.split(",")]
+
+    def parse_floats(s):
+        return [float(i) for i in s.split(",")]
+
+    # parse block_dims and block_alphas; they always end up with values
+    if block_dims is not None:
+        block_dims = parse_ints(block_dims)
+        assert (
+            len(block_dims) == num_total_blocks
+        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
+    else:
+        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        block_dims = [network_dim] * num_total_blocks
+
+    if block_alphas is not None:
+        block_alphas = parse_floats(block_alphas)
+        assert (
+            len(block_alphas) == num_total_blocks
+        ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
+    else:
+        print(
+            f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
+        )
+        block_alphas = [network_alpha] * num_total_blocks
+
+    # parse conv_block_dims and conv_block_alphas only if they are specified; otherwise fall back to conv_dim and conv_alpha
+    if conv_block_dims is not None:
+        conv_block_dims = parse_ints(conv_block_dims)
+        assert (
+            len(conv_block_dims) == num_total_blocks
+        ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
+
+        if conv_block_alphas is not None:
+            conv_block_alphas = parse_floats(conv_block_alphas)
+            assert (
+                len(conv_block_alphas) == num_total_blocks
+            ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
+        else:
+            if conv_alpha is None:
+                conv_alpha = 1.0
+            print(
+                f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
+            )
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+    else:
+        if conv_dim is not None:
+            print(
+                f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
+            )
+            conv_block_dims = [conv_dim] * num_total_blocks
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+        else:
+            conv_block_dims = None
+            conv_block_alphas = None
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
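+
+# Illustrative example (added comment, not part of the original code): the per-block lists must have
+# NUM_OF_BLOCKS * 2 + 1 = 25 entries (12 down blocks, 1 mid block, 12 up blocks), e.g.:
+#
+#   dims, alphas, conv_dims, conv_alphas = get_block_dims_and_alphas(
+#       ",".join(["4"] * 25), None, 8, 1.0, None, None, None, None
+#   )
+#   # -> dims == [4] * 25, alphas == [1.0] * 25, conv_dims is None, conv_alphas is None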
+
+
+# define per-block multipliers applied to the learning rate (block-wise LR); may be called from outside this module
+def get_block_lr_weight(
+    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
+) -> Tuple[List[float], List[float], List[float]]:
+    # if no parameter is specified, do nothing and keep the previous behavior
+    if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
+        return None, None, None
+
+    max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model
+
+    def get_list(name_with_suffix) -> List[float]:
+        import math
+
+        tokens = name_with_suffix.split("+")
+        name = tokens[0]
+        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
+
+        if name == "cosine":
+            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
+        elif name == "sine":
+            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
+        elif name == "linear":
+            return [i / (max_len - 1) + base_lr for i in range(max_len)]
+        elif name == "reverse_linear":
+            return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
+        elif name == "zeros":
+            return [0.0 + base_lr] * max_len
+        else:
+            print(
+                "Unknown lr_weight argument %s is used. Valid arguments:  / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
+                % (name, name)
+            )
+            return None
+
+    if type(down_lr_weight) == str:
+        down_lr_weight = get_list(down_lr_weight)
+    if type(up_lr_weight) == str:
+        up_lr_weight = get_list(up_lr_weight)
+
+    if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+        if up_lr_weight != None:
+            up_lr_weight = up_lr_weight[:max_len]
+        if down_lr_weight != None:
+            down_lr_weight = down_lr_weight[:max_len]
+
+    if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+
+        if down_lr_weight != None and len(down_lr_weight) < max_len:
+            down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+        if up_lr_weight != None and len(up_lr_weight) < max_len:
+            up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+
+    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
+        print("apply block learning rate / 階層別学習率を適用します。")
+        if down_lr_weight != None:
+            down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
+        else:
+            print("down_lr_weight: all 1.0, すべて1.0")
+
+        if mid_lr_weight != None:
+            mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+            print("mid_lr_weight:", mid_lr_weight)
+        else:
+            print("mid_lr_weight: 1.0")
+
+        if up_lr_weight != None:
+            up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
+        else:
+            print("up_lr_weight: all 1.0, すべて1.0")
+
+    return down_lr_weight, mid_lr_weight, up_lr_weight
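+
+# Illustrative example (added comment, not part of the original code): preset names expand to 12-element
+# lists and values at or below zero_threshold are clamped to 0 (the block is then dropped later):
+#
+#   down, mid, up = get_block_lr_weight("linear", 1.0, "zeros+0.5", 0.0)
+#   # down == [0, 1/11, 2/11, ..., 1.0]  (shallower -> deeper), mid == 1.0, up == [0.5] * 12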
+
+
+# exclude blocks whose lr_weight is 0 from block_dims; may be called from outside this module
+def remove_block_dims_and_alphas(
+    block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+):
+    # set 0 to block dim without learning rate to remove the block
+    if down_lr_weight != None:
+        for i, lr in enumerate(down_lr_weight):
+            if lr == 0:
+                block_dims[i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[i] = 0
+    if mid_lr_weight != None:
+        if mid_lr_weight == 0:
+            block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+            if conv_block_dims is not None:
+                conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+    if up_lr_weight != None:
+        for i, lr in enumerate(up_lr_weight):
+            if lr == 0:
+                block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
+
+
+# may be called from outside this module
+def get_block_index(lora_name: str) -> int:
+    block_idx = -1  # invalid lora name
+
+    m = RE_UPDOWN.search(lora_name)
+    if m:
+        g = m.groups()
+        i = int(g[1])
+        j = int(g[3])
+        if g[2] == "resnets":
+            idx = 3 * i + j
+        elif g[2] == "attentions":
+            idx = 3 * i + j
+        elif g[2] == "upsamplers" or g[2] == "downsamplers":
+            idx = 3 * i + 2
+
+        if g[0] == "down":
+            block_idx = 1 + idx  # there is no LoRA corresponding to index 0
+        elif g[0] == "up":
+            block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+
+    elif "mid_block_" in lora_name:
+        block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+
+    return block_idx
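+
+# Illustrative example (added comment, not part of the original code): down blocks map to indices starting
+# at 1, the mid block maps to NUM_OF_BLOCKS (12), and up blocks map to NUM_OF_BLOCKS + 1 and above:
+#
+#   get_block_index("lora_unet_down_blocks_0_attentions_0_proj_in")  # -> 1
+#   get_block_index("lora_unet_mid_block_attentions_0_proj_in")      # -> 12
+#   get_block_index("lora_unet_up_blocks_3_attentions_2_proj_out")   # -> 24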
+
+
+# Create a network from weights for inference; the weights are not loaded here (because they may be merged instead)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file, safe_open
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lora_down" in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha[key] = modules_dim[key]
+
+    module_class = LoRAInfModule if for_inference else LoRAModule
+
+    network = LoRANetwork(
+        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+    )
+
+    # block lr
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+
+    return network, weights_sd
+
+
+class LoRANetwork(torch.nn.Module):
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def __init__(
+        self,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+        unet,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        rank_dropout: Optional[float] = None,
+        module_dropout: Optional[float] = None,
+        conv_lora_dim: Optional[int] = None,
+        conv_alpha: Optional[float] = None,
+        block_dims: Optional[List[int]] = None,
+        block_alphas: Optional[List[float]] = None,
+        conv_block_dims: Optional[List[int]] = None,
+        conv_block_alphas: Optional[List[float]] = None,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
+        module_class: Type[object] = LoRAModule,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        """
+        LoRA network: すごく引数が多いが、パターンは以下の通り
+        1. lora_dimとalphaを指定
+        2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
+        3. block_dimsとblock_alphasを指定 :  Conv2d3x3には適用しない
+        4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
+        5. modules_dimとmodules_alphaを指定 (推論用)
+        """
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.conv_lora_dim = conv_lora_dim
+        self.conv_alpha = conv_alpha
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        elif block_dims is not None:
+            print(f"create LoRA network from block_dims")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            print(f"block_dims: {block_dims}")
+            print(f"block_alphas: {block_alphas}")
+            if conv_block_dims is not None:
+                print(f"conv_block_dims: {conv_block_dims}")
+                print(f"conv_block_alphas: {conv_block_alphas}")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            if self.conv_lora_dim is not None:
+                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+
+        # create module instances
+        def create_modules(
+            is_unet: bool,
+            text_encoder_idx: Optional[int],  # None, 1, 2
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_UNET
+                if is_unet
+                else (
+                    self.LORA_PREFIX_TEXT_ENCODER
+                    if text_encoder_idx is None
+                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+                )
+            )
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+
+                            if modules_dim is not None:
+                                # per-module dims/alphas are specified (loading from weights)
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            elif is_unet and block_dims is not None:
+                                # U-Net with block_dims specified
+                                block_idx = get_block_index(lora_name)
+                                if is_linear or is_conv2d_1x1:
+                                    dim = block_dims[block_idx]
+                                    alpha = block_alphas[block_idx]
+                                elif conv_block_dims is not None:
+                                    dim = conv_block_dims[block_idx]
+                                    alpha = conv_block_alphas[block_idx]
+                            else:
+                                # default: apply to all target modules
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.conv_lora_dim is not None:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
+
+                            if dim is None or dim == 0:
+                                # record the skipped module for reporting
+                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
+                                    skipped.append(lora_name)
+                                continue
+
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                                rank_dropout=rank_dropout,
+                                module_dropout=module_dropout,
+                            )
+                            loras.append(lora)
+            return loras, skipped
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        # create LoRA for text encoder
+        # creating all modules every time is wasteful; this should be reconsidered
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+                print(f"create LoRA for Text Encoder {index}:")
+            else:
+                index = None
+                print(f"create LoRA for Text Encoder:")
+
+            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+            skipped_te += skipped
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
+            # use "+" (not "+=") so the class attribute list is not mutated in place
+            target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        skipped = skipped_te + skipped_un
+        if varbose and len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
+            for name in skipped:
+                print(f"\t{name}")
+
+        self.up_lr_weight: List[float] = None
+        self.down_lr_weight: List[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    # returns whether the weights can be merged
+    def is_mergeable(self):
+        return True
+
+    # TODO refactor to common function with apply_to
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        apply_text_encoder = apply_unet = False
+        for key in weights_sd.keys():
+            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                apply_text_encoder = True
+            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                apply_unet = True
+
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(lora.lora_name):
+                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+            lora.merge_to(sd_for_lora, dtype, device)
+
+        print(f"weights are merged")
+
+    # define per-block multipliers for the learning rate; the argument order (up, mid, down) is reversed, but leave it for now
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: List[float] = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: List[float] = None,
+    ):
+        self.block_lr = True
+        self.down_lr_weight = down_lr_weight
+        self.mid_lr_weight = mid_lr_weight
+        self.up_lr_weight = up_lr_weight
+
+    def get_lr_weight(self, lora: LoRAModule) -> float:
+        lr_weight = 1.0
+        block_idx = get_block_index(lora.lora_name)
+        if block_idx < 0:
+            return lr_weight
+
+        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
+            if self.down_lr_weight != None:
+                lr_weight = self.down_lr_weight[block_idx]
+        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
+            if self.mid_lr_weight != None:
+                lr_weight = self.mid_lr_weight
+        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
+            if self.up_lr_weight != None:
+                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
+
+        return lr_weight
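+
+    # Note (added comment, not part of the original code): with block-wise LR enabled, a U-Net module such as
+    # "lora_unet_down_blocks_1_attentions_0_..." (block index 4) is trained with lr = unet_lr * down_lr_weight[4];
+    # modules whose name does not resolve to a block keep the default weight of 1.0.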
+
+    # it might be better to allow separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(loras):
+            params = []
+            for lora in loras:
+                params.extend(lora.parameters())
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
+
+        if self.unet_loras:
+            if self.block_lr:
+                # group LoRA modules by block so that the learning rate can be tracked per block
+                block_idx_to_lora = {}
+                for lora in self.unet_loras:
+                    idx = get_block_index(lora.lora_name)
+                    if idx not in block_idx_to_lora:
+                        block_idx_to_lora[idx] = []
+                    block_idx_to_lora[idx].append(lora)
+
+                # set the parameter group for each block
+                for idx, block_loras in block_idx_to_lora.items():
+                    param_data = {"params": enumerate_params(block_loras)}
+
+                    if unet_lr is not None:
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    elif default_lr is not None:
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
+                        continue
+                    all_params.append(param_data)
+
+            else:
+                param_data = {"params": enumerate_params(self.unet_loras)}
+                if unet_lr is not None:
+                    param_data["lr"] = unet_lr
+                all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    # mask is a tensor with values from 0 to 1
+    def set_region(self, sub_prompt_index, is_last_network, mask):
+        if mask.max() == 0:
+            mask = torch.ones_like(mask)
+
+        self.mask = mask
+        self.sub_prompt_index = sub_prompt_index
+        self.is_last_network = is_last_network
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.set_network(self)
+
+    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
+        self.batch_size = batch_size
+        self.num_sub_prompts = num_sub_prompts
+        self.current_size = (height, width)
+        self.shared = shared
+
+        # create masks
+        mask = self.mask
+        mask_dic = {}
+        mask = mask.unsqueeze(0).unsqueeze(1)  # b(1),c(1),h,w
+        ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
+        dtype = ref_weight.dtype
+        device = ref_weight.device
+
+        def resize_add(mh, mw):
+            # print(mh, mw, mh * mw)
+            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
+            m = m.to(device, dtype=dtype)
+            mask_dic[mh * mw] = m
+
+        h = height // 8
+        w = width // 8
+        for _ in range(4):
+            resize_add(h, w)
+            if h % 2 == 1 or w % 2 == 1:  # add extra shape if h/w is not divisible by 2
+                resize_add(h + h % 2, w + w % 2)
+            h = (h + 1) // 2
+            w = (w + 1) // 2
+
+        self.mask_dic = mask_dic
+
+    def backup_weights(self):
+        # back up the original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        # restore the original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        # pre-calculate the merged weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            sd = org_module.state_dict()
+
+            org_weight = sd["weight"]
+            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            sd["weight"] = org_weight + lora_weight
+            assert sd["weight"].shape == org_weight.shape
+            org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            lora.enabled = False
+
+    def apply_max_norm_regularization(self, max_norm_value, device):
+        downkeys = []
+        upkeys = []
+        alphakeys = []
+        norms = []
+        keys_scaled = 0
+
+        state_dict = self.state_dict()
+        for key in state_dict.keys():
+            if "lora_down" in key and "weight" in key:
+                downkeys.append(key)
+                upkeys.append(key.replace("lora_down", "lora_up"))
+                alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+        for i in range(len(downkeys)):
+            down = state_dict[downkeys[i]].to(device)
+            up = state_dict[upkeys[i]].to(device)
+            alpha = state_dict[alphakeys[i]].to(device)
+            dim = down.shape[0]
+            scale = alpha / dim
+
+            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+            else:
+                updown = up @ down
+
+            updown *= scale
+
+            norm = updown.norm().clamp(min=max_norm_value / 2)
+            desired = torch.clamp(norm, max=max_norm_value)
+            ratio = desired.cpu() / norm.cpu()
+            sqrt_ratio = ratio**0.5
+            if ratio != 1:
+                keys_scaled += 1
+                state_dict[upkeys[i]] *= sqrt_ratio
+                state_dict[downkeys[i]] *= sqrt_ratio
+            scalednorm = updown.norm() * ratio
+            norms.append(scalednorm.item())
+
+        return keys_scaled, sum(norms) / len(norms), max(norms)
diff --git a/external/llite/networks/lora_diffusers.py b/external/llite/networks/lora_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..47d75ac4d103362a0d3e17c294dbf35d690f99fa
--- /dev/null
+++ b/external/llite/networks/lora_diffusers.py
@@ -0,0 +1,609 @@
+# Diffusersで動くLoRA。このファイル単独で完結する。
+# LoRA module for Diffusers. This file works independently.
+
+import bisect
+import math
+import random
+from typing import Any, Dict, List, Mapping, Optional, Union
+from diffusers import UNet2DConditionModel
+import numpy as np
+from tqdm import tqdm
+from transformers import CLIPTextModel
+import torch
+
+
+def make_unet_conversion_map() -> Dict[str, str]:
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # the "if i > 0:" guard (no attention layers in up_blocks.0 for SD1.x) is commented out for sdxl
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+    return sd_hf_conversion_map
+
+
+UNET_CONVERSION_MAP = make_unet_conversion_map()
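+
+# Illustrative example (added comment, not part of the original code): keys are Stability-AI-style module
+# name prefixes with dots replaced by underscores, values are their Diffusers counterparts, e.g.
+#
+#   UNET_CONVERSION_MAP["input_blocks_1_0_in_layers_0"]  # -> "down_blocks_0_resnets_0_norm1"
+#   UNET_CONVERSION_MAP["middle_block_1"]                # -> "mid_block_attentions_0"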
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    replaces the forward method of the original Linear instead of replacing the original Linear module itself.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # 勾配計算に含めない / not included in gradient calculation
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = [org_module]
+        self.enabled = True
+        self.network: LoRANetwork = None
+        self.org_forward = None
+
+    # override org_module's forward method
+    def apply_to(self, multiplier=None):
+        if multiplier is not None:
+            self.multiplier = multiplier
+        if self.org_forward is None:
+            self.org_forward = self.org_module[0].forward
+            self.org_module[0].forward = self.forward
+
+    # restore org_module's forward method
+    def unapply_to(self):
+        if self.org_forward is not None:
+            self.org_module[0].forward = self.org_forward
+
+    # forward with lora
+    # "scale" is passed by LoRACompatibleConv, but we ignore it because we already have multiplier
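+    # (added comment) the LoRA forward computes: y = org_forward(x) + lora_up(lora_down(x)) * multiplier * (alpha / rank)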
+    def forward(self, x, scale=1.0):
+        if not self.enabled:
+            return self.org_forward(x)
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+    def set_network(self, network):
+        self.network = network
+
+    # merge lora weight to org weight
+    def merge_to(self, multiplier=1.0):
+        # get lora weight
+        lora_weight = self.get_weight(multiplier)
+
+        # get org weight
+        org_sd = self.org_module[0].state_dict()
+        org_weight = org_sd["weight"]
+        weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
+
+        # set weight to org_module
+        org_sd["weight"] = weight
+        self.org_module[0].load_state_dict(org_sd)
+
+    # restore org weight from lora weight
+    def restore_from(self, multiplier=1.0):
+        # get lora weight
+        lora_weight = self.get_weight(multiplier)
+
+        # get org weight
+        org_sd = self.org_module[0].state_dict()
+        org_weight = org_sd["weight"]
+        weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
+
+        # set weight to org_module
+        org_sd["weight"] = weight
+        self.org_module[0].load_state_dict(org_sd)
+
+    # return lora weight
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # get up/down weight from module
+        up_weight = self.lora_up.weight.to(torch.float)
+        down_weight = self.lora_down.weight.to(torch.float)
+
+        # pre-calculate the weight delta (use the resolved multiplier argument, not self.multiplier)
+        if len(down_weight.size()) == 2:
+            # linear
+            weight = multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            weight = multiplier * conved * self.scale
+
+        return weight
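+
+    # Note (added comment, not part of the original code): get_weight() returns the dense update
+    # (lora_up composed with lora_down, scaled by multiplier * alpha / rank), so merge_to() and
+    # restore_from() simply add or subtract it from the wrapped module's "weight" tensor.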
+
+
+# Create network from weights for inference, weights are not loaded here
+def create_network_from_weights(
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
+):
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lora_down" in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha[key] = modules_dim[key]
+
+    return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
+
+
+def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
+    text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
+    unet = pipe.unet
+
+    lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
+    lora_network.load_state_dict(weights_sd)
+    lora_network.merge_to(multiplier=multiplier)
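+
+# Illustrative usage (added comment, not part of the original code), assuming `pipe` is an already-loaded
+# StableDiffusionPipeline and "lora.safetensors" is a LoRA checkpoint using this naming scheme:
+#
+#   from safetensors.torch import load_file
+#   merge_lora_weights(pipe, load_file("lora.safetensors"), multiplier=0.8)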
+
+
+# block weightや学習に対応しない簡易版 / simple version without block weight and training
+class LoRANetwork(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def __init__(
+        self,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+        unet: UNet2DConditionModel,
+        multiplier: float = 1.0,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        print(f"create LoRA network from weights")
+
+        # convert SDXL Stability AI's U-Net modules to Diffusers
+        converted = self.convert_unet_modules(modules_dim, modules_alpha)
+        if converted:
+            print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
+
+        # create module instances
+        def create_modules(
+            is_unet: bool,
+            text_encoder_idx: Optional[int],  # None, 1, 2
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_UNET
+                if is_unet
+                else (
+                    self.LORA_PREFIX_TEXT_ENCODER
+                    if text_encoder_idx is None
+                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+                )
+            )
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = (
+                            child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+                        )
+                        is_conv2d = (
+                            child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+                        )
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            if lora_name not in modules_dim:
+                                # print(f"skipped {lora_name} (not found in modules_dim)")
+                                skipped.append(lora_name)
+                                continue
+
+                            dim = modules_dim[lora_name]
+                            alpha = modules_alpha[lora_name]
+                            lora = LoRAModule(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                            )
+                            loras.append(lora)
+            return loras, skipped
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        # create LoRA for text encoder
+        # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
+        self.text_encoder_loras: List[LoRAModule] = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+            else:
+                index = None
+
+            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+            skipped_te += skipped
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+        if len(skipped_te) > 0:
+            print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
+
+        # extend U-Net target modules to include Conv2d 3x3
+        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_loras: List[LoRAModule]
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+        if len(skipped_un) > 0:
+            print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            names.add(lora.lora_name)
+        for lora_name in modules_dim.keys():
+            assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
+
+        # make to work load_state_dict
+        for lora in self.text_encoder_loras + self.unet_loras:
+            self.add_module(lora.lora_name, lora)
+
+    # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
+    def convert_unet_modules(self, modules_dim, modules_alpha):
+        converted_count = 0
+        not_converted_count = 0
+
+        map_keys = list(UNET_CONVERSION_MAP.keys())
+        map_keys.sort()
+
+        for key in list(modules_dim.keys()):
+            if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
+                search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
+                position = bisect.bisect_right(map_keys, search_key)
+                map_key = map_keys[position - 1]
+                if search_key.startswith(map_key):
+                    new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
+                    modules_dim[new_key] = modules_dim[key]
+                    modules_alpha[new_key] = modules_alpha[key]
+                    del modules_dim[key]
+                    del modules_alpha[key]
+                    converted_count += 1
+                else:
+                    not_converted_count += 1
+        assert (
+            converted_count == 0 or not_converted_count == 0
+        ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
+        return converted_count
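+
+    # Illustrative example (added comment, not part of the original code): a Stability-AI-style key such as
+    # "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q" is renamed in place to the Diffusers-style
+    # "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q" (the suffix after the mapped prefix is kept).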
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+            for lora in self.text_encoder_loras:
+                lora.apply_to(multiplier)
+        if apply_unet:
+            print("enable LoRA for U-Net")
+            for lora in self.unet_loras:
+                lora.apply_to(multiplier)
+
+    def unapply_to(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.unapply_to()
+
+    def merge_to(self, multiplier=1.0):
+        print("merge LoRA weights to original weights")
+        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
+            lora.merge_to(multiplier)
+        print(f"weights are merged")
+
+    def restore_from(self, multiplier=1.0):
+        print("restore LoRA weights from original weights")
+        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
+            lora.restore_from(multiplier)
+        print(f"weights are restored")
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        # convert SDXL Stability AI's state dict to Diffusers' based state dict
+        map_keys = list(UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
+        map_keys.sort()
+        for key in list(state_dict.keys()):
+            if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
+                search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
+                position = bisect.bisect_right(map_keys, search_key)
+                map_key = map_keys[position - 1]
+                if search_key.startswith(map_key):
+                    new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
+                    state_dict[new_key] = state_dict[key]
+                    del state_dict[key]
+
+        # in case of V2, some weights have different shape, so we need to convert them
+        # because V2 LoRA is based on U-Net created by use_linear_projection=False
+        my_state_dict = self.state_dict()
+        for key in state_dict.keys():
+            if state_dict[key].size() != my_state_dict[key].size():
+                # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
+                state_dict[key] = state_dict[key].view(my_state_dict[key].size())
+
+        return super().load_state_dict(state_dict, strict)
+
+
+if __name__ == "__main__":
+    # sample code to use LoRANetwork
+    import os
+    import argparse
+    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+    import torch
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
+    parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
+    parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
+    parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
+    parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
+    parser.add_argument("--seed", type=int, default=0, help="random seed")
+    args = parser.parse_args()
+
+    image_prefix = args.model_id.replace("/", "_") + "_"
+
+    # load Diffusers model
+    print(f"load model from {args.model_id}")
+    pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
+    if args.sdxl:
+        # use_safetensors=True does not work with 0.18.2
+        pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
+    pipe.to(device)
+    pipe.set_use_memory_efficient_attention_xformers(True)
+
+    text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
+
+    # load LoRA weights
+    print(f"load LoRA weights from {args.lora_weights}")
+    if os.path.splitext(args.lora_weights)[1] == ".safetensors":
+        from safetensors.torch import load_file
+
+        lora_sd = load_file(args.lora_weights)
+    else:
+        lora_sd = torch.load(args.lora_weights)
+
+    # create the network from the LoRA weights and load them
+    print(f"create LoRA network")
+    lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
+
+    print(f"load LoRA network weights")
+    lora_network.load_state_dict(lora_sd)
+
+    lora_network.to(device, dtype=pipe.unet.dtype)  # required to apply_to. merge_to works without this
+
+    # 必要があれば、元のモデルの重みをバックアップしておく
+    # back-up unet/text encoder weights if necessary
+    def detach_and_move_to_cpu(state_dict):
+        for k, v in state_dict.items():
+            state_dict[k] = v.detach().cpu()
+        return state_dict
+
+    org_unet_sd = pipe.unet.state_dict()
+    detach_and_move_to_cpu(org_unet_sd)
+
+    org_text_encoder_sd = pipe.text_encoder.state_dict()
+    detach_and_move_to_cpu(org_text_encoder_sd)
+
+    if args.sdxl:
+        org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
+        detach_and_move_to_cpu(org_text_encoder_2_sd)
+
+    def seed_everything(seed):
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+    # create image with original weights
+    print(f"create image with original weights")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "original.png")
+
+    # apply LoRA network to the model: slower than merge_to, but can be reverted easily
+    print(f"apply LoRA network to the model")
+    lora_network.apply_to(multiplier=1.0)
+
+    print(f"create image with applied LoRA")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "applied_lora.png")
+
+    # remove (unapply) the LoRA network from the model
+    print(f"unapply LoRA network from the model")
+    lora_network.unapply_to()
+
+    print(f"create image with unapplied LoRA")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "unapplied_lora.png")
+
+    # merge LoRA network into the model: faster than apply_to, but requires backing up the original weights (or calling restore_from to unmerge)
+    print(f"merge LoRA network to the model")
+    lora_network.merge_to(multiplier=1.0)
+
+    print(f"create image with LoRA")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "merged_lora.png")
+
+    # restore (unmerge) LoRA weights: numerically unstable
+    # the merged weights are reverted; due to floating-point error they may not exactly match the original weights
+    # restoring from a saved state_dict is the reliable way to get the original weights back
+    print(f"restore (unmerge) LoRA weights")
+    lora_network.restore_from(multiplier=1.0)
+
+    print(f"create image without LoRA")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "unmerged_lora.png")
+
+    # restore original weights
+    print(f"restore original weights")
+    pipe.unet.load_state_dict(org_unet_sd)
+    pipe.text_encoder.load_state_dict(org_text_encoder_sd)
+    if args.sdxl:
+        pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
+
+    print(f"create image with restored original weights")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "restore_original.png")
+
+    # use convenience function to merge LoRA weights
+    print(f"merge LoRA weights with convenience function")
+    merge_lora_weights(pipe, lora_sd, multiplier=1.0)
+
+    print(f"create image with merged LoRA weights")
+    seed_everything(args.seed)
+    image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
+    image.save(image_prefix + "convenience_merged_lora.png")
diff --git a/external/llite/networks/lora_fa.py b/external/llite/networks/lora_fa.py
new file mode 100644
index 0000000000000000000000000000000000000000..a357d7f7fcee2d9435af9a86324b03b111fd8aad
--- /dev/null
+++ b/external/llite/networks/lora_fa.py
@@ -0,0 +1,1241 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+# temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303
+# needs to be refactored and merged into lora.py
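+#
+# Rough sketch of the LoRA-FA idea (illustration only): the down projection A (lora_down) is frozen
+# after initialization and only the up projection B (lora_up) is trained, so the adapted layer computes
+#   y = W x + multiplier * scale * B(A(x))
+# See LoRAModule.requires_grad_ / get_trainable_params below for how the freezing is implemented here.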
+
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import re
+
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear/Conv2d module instead of replacing the module itself.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        # if limit_rank:
+        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
+        #   if self.lora_dim != lora_dim:
+        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        # else:
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == "Conv2d":
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+        # # same as microsoft's
+        # torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+
+        # according to the paper, initialize LoRA-A (down) with a normal distribution
+        torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim)))
+
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # removed when apply_to is called
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def get_trainable_params(self):
+        params = self.named_parameters()
+        trainable_params = []
+        for param in params:
+            if param[0] == "lora_up.weight":  # up only
+                trainable_params.append(param[1])
+        return trainable_params
+
+    def requires_grad_(self, requires_grad: bool = True):
+        self.lora_up.requires_grad_(requires_grad)
+        self.lora_down.requires_grad_(False)
+        return self
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x):
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x)
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            # the scale could also be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # written out explicitly for readability
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * scale
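+
+    # Note (illustrative): with the usual dropout convention E[mask] = 1 - rank_dropout,
+    # so the extra factor 1 / (1 - rank_dropout) keeps the expected value of the masked
+    # LoRA activations equal to the unmasked ones.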
+
+
+class LoRAInfModule(LoRAModule):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+        self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
+        self.enabled = True
+
+        # check regional or not by lora_name
+        self.text_encoder = False
+        if lora_name.startswith("lora_te_"):
+            self.regional = False
+            self.use_sub_prompt = True
+            self.text_encoder = True
+        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = True
+        elif "time_emb" in lora_name:
+            self.regional = False
+            self.use_sub_prompt = False
+        else:
+            self.regional = True
+            self.use_sub_prompt = False
+
+        self.network: LoRANetwork = None
+
+    def set_network(self, network):
+        self.network = network
+
+    # merge the LoRA weights into the original (frozen) module
+    def merge_to(self, sd, dtype, device):
+        # get up/down weight
+        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+
+        # extract weight from org_module
+        org_sd = self.org_module.state_dict()
+        weight = org_sd["weight"].to(torch.float)
+
+        # merge weight
+        if len(weight.size()) == 2:
+            # linear
+            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                weight
+                + self.multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            # print(conved.size(), weight.size(), module.stride, module.padding)
+            weight = weight + self.multiplier * conved * self.scale
+
+        # set weight to org_module
+        org_sd["weight"] = weight.to(dtype)
+        self.org_module.load_state_dict(org_sd)
+
+    # return this module's weight delta so that a merge can be reverted later
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # get up/down weight from module
+        up_weight = self.lora_up.weight.to(torch.float)
+        down_weight = self.lora_down.weight.to(torch.float)
+
+        # pre-calculated weight
+        if len(down_weight.size()) == 2:
+            # linear
+            weight = self.multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                self.multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            weight = self.multiplier * conved * self.scale
+
+        return weight
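+
+    # Illustration (not executed here): pre_calculation() in LoRANetwork uses this as
+    #   new_weight = org_weight + get_weight()
+    # and then disables the module, which is equivalent to merging at the current multiplier.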
+
+    def set_region(self, region):
+        self.region = region
+        self.region_mask = None
+
+    def default_forward(self, x):
+        # print("default_forward", self.lora_name, x.size())
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+
+    def forward(self, x):
+        if not self.enabled:
+            return self.org_forward(x)
+
+        if self.network is None or self.network.sub_prompt_index is None:
+            return self.default_forward(x)
+        if not self.regional and not self.use_sub_prompt:
+            return self.default_forward(x)
+
+        if self.regional:
+            return self.regional_forward(x)
+        else:
+            return self.sub_prompt_forward(x)
+
+    def get_mask_for_x(self, x):
+        # calculate size from shape of x
+        if len(x.size()) == 4:
+            h, w = x.size()[2:4]
+            area = h * w
+        else:
+            area = x.size()[1]
+
+        mask = self.network.mask_dic[area]
+        if mask is None:
+            raise ValueError(f"mask is None for resolution {area}")
+        if len(x.size()) != 4:
+            mask = torch.reshape(mask, (1, -1, 1))
+        return mask
+
+    def regional_forward(self, x):
+        if "attn2_to_out" in self.lora_name:
+            return self.to_out_forward(x)
+
+        if self.network.mask_dic is None:  # sub_prompt_index >= 3
+            return self.default_forward(x)
+
+        # apply mask for LoRA result
+        lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        mask = self.get_mask_for_x(lx)
+        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
+        lx = lx * mask
+
+        x = self.org_forward(x)
+        x = x + lx
+
+        if "attn2_to_q" in self.lora_name and self.network.is_last_network:
+            x = self.postp_to_q(x)
+
+        return x
+
+    def postp_to_q(self, x):
+        # repeat x to num_sub_prompts
+        has_real_uncond = x.size()[0] // self.network.batch_size == 3
+        qc = self.network.batch_size  # uncond
+        qc += self.network.batch_size * self.network.num_sub_prompts  # cond
+        if has_real_uncond:
+            qc += self.network.batch_size  # real_uncond
+
+        query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
+        query[: self.network.batch_size] = x[: self.network.batch_size]
+
+        for i in range(self.network.batch_size):
+            qi = self.network.batch_size + i * self.network.num_sub_prompts
+            query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
+
+        if has_real_uncond:
+            query[-self.network.batch_size :] = x[-self.network.batch_size :]
+
+        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
+        return query
+
+    def sub_prompt_forward(self, x):
+        if x.size()[0] == self.network.batch_size:  # if uncond in text_encoder, do not apply LoRA
+            return self.org_forward(x)
+
+        emb_idx = self.network.sub_prompt_index
+        if not self.text_encoder:
+            emb_idx += self.network.batch_size
+
+        # apply sub prompt of X
+        lx = x[emb_idx :: self.network.num_sub_prompts]
+        lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
+
+        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
+
+        x = self.org_forward(x)
+        x[emb_idx :: self.network.num_sub_prompts] += lx
+
+        return x
+
+    def to_out_forward(self, x):
+        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
+
+        if self.network.is_last_network:
+            masks = [None] * self.network.num_sub_prompts
+            self.network.shared[self.lora_name] = (None, masks)
+        else:
+            lx, masks = self.network.shared[self.lora_name]
+
+        # call own LoRA
+        x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
+        lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
+
+        if self.network.is_last_network:
+            lx = torch.zeros(
+                (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
+            )
+            self.network.shared[self.lora_name] = (lx, masks)
+
+        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
+        lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
+        masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
+
+        # if not last network, return x and masks
+        x = self.org_forward(x)
+        if not self.network.is_last_network:
+            return x
+
+        lx, masks = self.network.shared.pop(self.lora_name)
+
+        # if last network, combine separated x with mask weighted sum
+        has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
+
+        out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
+        out[: self.network.batch_size] = x[: self.network.batch_size]  # uncond
+        if has_real_uncond:
+            out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond
+
+        # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # for i in range(len(masks)):
+        #     if masks[i] is None:
+        #         masks[i] = torch.zeros_like(masks[-1])
+
+        mask = torch.cat(masks)
+        mask_sum = torch.sum(mask, dim=0) + 1e-4
+        for i in range(self.network.batch_size):
+            # process each image separately
+            lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
+            lx1 = lx1 * mask
+            lx1 = torch.sum(lx1, dim=0)
+
+            xi = self.network.batch_size + i * self.network.num_sub_prompts
+            x1 = x[xi : xi + self.network.num_sub_prompts]
+            x1 = x1 * mask
+            x1 = torch.sum(x1, dim=0)
+            x1 = x1 / mask_sum
+
+            x1 = x1 + lx1
+            out[self.network.batch_size + i] = x1
+
+        # print("to_out_forward", x.size(), out.size(), has_real_uncond)
+        return out
+
+
+def parse_block_lr_kwargs(nw_kwargs):
+    down_lr_weight = nw_kwargs.get("down_lr_weight", None)
+    mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
+    up_lr_weight = nw_kwargs.get("up_lr_weight", None)
+
+    # if none of these is set, return None for all (block-wise lr disabled)
+    if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
+        return None, None, None
+
+    # extract learning rate weight for each block
+    if down_lr_weight is not None:
+        # if some parameters are not set, use zero
+        if "," in down_lr_weight:
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+
+    if up_lr_weight is not None:
+        if "," in up_lr_weight:
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+    down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
+        down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
+    )
+
+    return down_lr_weight, mid_lr_weight, up_lr_weight
+
+
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    neuron_dropout: Optional[float] = None,
+    **kwargs,
+):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    # extract dim/alpha for conv2d, and block dim
+    conv_dim = kwargs.get("conv_dim", None)
+    conv_alpha = kwargs.get("conv_alpha", None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = 1.0
+        else:
+            conv_alpha = float(conv_alpha)
+
+    # block dim/alpha/lr
+    block_dims = kwargs.get("block_dims", None)
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+
+    # if any of the above is specified, enable per-block dim (rank)
+    if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
+        block_alphas = kwargs.get("block_alphas", None)
+        conv_block_dims = kwargs.get("conv_block_dims", None)
+        conv_block_alphas = kwargs.get("conv_block_alphas", None)
+
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
+            block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+        )
+
+        # remove block dim/alpha without learning rate
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
+            block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+        )
+
+    else:
+        block_alphas = None
+        conv_block_dims = None
+        conv_block_alphas = None
+
+    # rank/module dropout
+    rank_dropout = kwargs.get("rank_dropout", None)
+    if rank_dropout is not None:
+        rank_dropout = float(rank_dropout)
+    module_dropout = kwargs.get("module_dropout", None)
+    if module_dropout is not None:
+        module_dropout = float(module_dropout)
+
+    # quite a lot of arguments here...
+    network = LoRANetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        lora_dim=network_dim,
+        alpha=network_alpha,
+        dropout=neuron_dropout,
+        rank_dropout=rank_dropout,
+        module_dropout=module_dropout,
+        conv_lora_dim=conv_dim,
+        conv_alpha=conv_alpha,
+        block_dims=block_dims,
+        block_alphas=block_alphas,
+        conv_block_dims=conv_block_dims,
+        conv_block_alphas=conv_block_alphas,
+        varbose=True,
+    )
+
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+
+    return network
+
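+# Usage sketch (hypothetical values; any of the keyword arguments can be omitted):
+#   network = create_network(
+#       1.0, 8, 4.0, vae, text_encoder, unet,
+#       conv_dim="4", conv_alpha="1.0",
+#       down_lr_weight="sine+0.1", mid_lr_weight="1.0", up_lr_weight="cosine+0.1",
+#   )
+#   network.apply_to(text_encoder, unet)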
+
+# this method may be called from outside this module, so keep its interface stable
+# network_dim and network_alpha already contain their default values
+# block_dims and block_alphas are either both None or both set
+# conv_dim and conv_alpha are either both None or both set
+def get_block_dims_and_alphas(
+    block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+):
+    num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
+
+    def parse_ints(s):
+        return [int(i) for i in s.split(",")]
+
+    def parse_floats(s):
+        return [float(i) for i in s.split(",")]
+
+    # parse block_dims and block_alphas; both always end up with values
+    if block_dims is not None:
+        block_dims = parse_ints(block_dims)
+        assert (
+            len(block_dims) == num_total_blocks
+        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
+    else:
+        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        block_dims = [network_dim] * num_total_blocks
+
+    if block_alphas is not None:
+        block_alphas = parse_floats(block_alphas)
+        assert (
+            len(block_alphas) == num_total_blocks
+        ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
+    else:
+        print(
+            f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
+        )
+        block_alphas = [network_alpha] * num_total_blocks
+
+    # parse conv_block_dims and conv_block_alphas only if specified; otherwise fall back to conv_dim and conv_alpha
+    if conv_block_dims is not None:
+        conv_block_dims = parse_ints(conv_block_dims)
+        assert (
+            len(conv_block_dims) == num_total_blocks
+        ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
+
+        if conv_block_alphas is not None:
+            conv_block_alphas = parse_floats(conv_block_alphas)
+            assert (
+                len(conv_block_alphas) == num_total_blocks
+            ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
+        else:
+            if conv_alpha is None:
+                conv_alpha = 1.0
+            print(
+                f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
+            )
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+    else:
+        if conv_dim is not None:
+            print(
+                f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
+            )
+            conv_block_dims = [conv_dim] * num_total_blocks
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+        else:
+            conv_block_dims = None
+            conv_block_alphas = None
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
+
+
+# define per-block multipliers for block-wise learning rates; may be called from outside this module
+def get_block_lr_weight(
+    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
+) -> Tuple[List[float], List[float], List[float]]:
+    # if no parameters are specified, do nothing and keep the previous behavior
+    if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
+        return None, None, None
+
+    max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model
+
+    def get_list(name_with_suffix) -> List[float]:
+        tokens = name_with_suffix.split("+")
+        name = tokens[0]
+        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
+
+        if name == "cosine":
+            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
+        elif name == "sine":
+            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
+        elif name == "linear":
+            return [i / (max_len - 1) + base_lr for i in range(max_len)]
+        elif name == "reverse_linear":
+            return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
+        elif name == "zeros":
+            return [0.0 + base_lr] * max_len
+        else:
+            print(
+                "Unknown lr_weight argument %s is used. Valid arguments:  / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
+                % (name, name)
+            )
+            return None
+
+    if type(down_lr_weight) == str:
+        down_lr_weight = get_list(down_lr_weight)
+    if type(up_lr_weight) == str:
+        up_lr_weight = get_list(up_lr_weight)
+
+    if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+        if up_lr_weight != None:
+            up_lr_weight = up_lr_weight[:max_len]
+        if down_lr_weight != None:
+            down_lr_weight = down_lr_weight[:max_len]
+
+    if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+
+        if down_lr_weight != None and len(down_lr_weight) < max_len:
+            down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+        if up_lr_weight != None and len(up_lr_weight) < max_len:
+            up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+
+    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
+        print("apply block learning rate / 階層別学習率を適用します。")
+        if down_lr_weight != None:
+            down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
+        else:
+            print("down_lr_weight: all 1.0, すべて1.0")
+
+        if mid_lr_weight != None:
+            mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+            print("mid_lr_weight:", mid_lr_weight)
+        else:
+            print("mid_lr_weight: 1.0")
+
+        if up_lr_weight != None:
+            up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
+        else:
+            print("up_lr_weight: all 1.0, すべて1.0")
+
+    return down_lr_weight, mid_lr_weight, up_lr_weight
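+
+# Illustration of the presets parsed by get_list above (assuming max_len = 12, values rounded):
+#   "sine"        -> 0.00, 0.14, 0.28, ..., 1.00  (increasing from the first to the last entry)
+#   "cosine"      -> the same curve reversed
+#   "linear+0.5"  -> 0.50, 0.59, ..., 1.50        (a "+x" suffix adds a constant offset)
+# A comma-separated list such as "1,1,1,1,1,1,1,1,0,0,0,0" gives explicit per-block weights instead.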
+
+
+# remove blocks whose lr_weight is 0 from block_dims; may be called from outside this module
+def remove_block_dims_and_alphas(
+    block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+):
+    # set 0 to block dim without learning rate to remove the block
+    if down_lr_weight != None:
+        for i, lr in enumerate(down_lr_weight):
+            if lr == 0:
+                block_dims[i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[i] = 0
+    if mid_lr_weight != None:
+        if mid_lr_weight == 0:
+            block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+            if conv_block_dims is not None:
+                conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+    if up_lr_weight != None:
+        for i, lr in enumerate(up_lr_weight):
+            if lr == 0:
+                block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
+
+
+# may be called from outside this module
+def get_block_index(lora_name: str) -> int:
+    block_idx = -1  # invalid lora name
+
+    m = RE_UPDOWN.search(lora_name)
+    if m:
+        g = m.groups()
+        i = int(g[1])
+        j = int(g[3])
+        if g[2] == "resnets":
+            idx = 3 * i + j
+        elif g[2] == "attentions":
+            idx = 3 * i + j
+        elif g[2] == "upsamplers" or g[2] == "downsamplers":
+            idx = 3 * i + 2
+
+        if g[0] == "down":
+            block_idx = 1 + idx  # there is no LoRA corresponding to index 0
+        elif g[0] == "up":
+            block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+
+    elif "mid_block_" in lora_name:
+        block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+
+    return block_idx
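+
+# Example (illustration): "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q"
+# matches ("down", 1, "attentions", 0), so idx = 3 * 1 + 0 = 3 and block_idx = 1 + 3 = 4;
+# "lora_unet_mid_block_attentions_0_..." maps to block_idx = NUM_OF_BLOCKS = 12.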
+
+
+# Create a network from weights for inference; the weights are not loaded here (they may be merged instead)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lora_down" in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha[key] = modules_dim[key]
+
+    module_class = LoRAInfModule if for_inference else LoRAModule
+
+    network = LoRANetwork(
+        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+    )
+
+    # block lr
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+
+    return network, weights_sd
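+
+# Usage sketch for inference (file name and dtype/device are placeholders):
+#   network, weights_sd = create_network_from_weights(1.0, "my_lora.safetensors", vae, text_encoder, unet, for_inference=True)
+#   network.merge_to(text_encoder, unet, weights_sd, torch.float16, "cuda")
+#   # or: network.apply_to(text_encoder, unet); network.load_state_dict(weights_sd) for a revertible application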
+
+
+class LoRANetwork(torch.nn.Module):
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+    def __init__(
+        self,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+        unet,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        rank_dropout: Optional[float] = None,
+        module_dropout: Optional[float] = None,
+        conv_lora_dim: Optional[int] = None,
+        conv_alpha: Optional[float] = None,
+        block_dims: Optional[List[int]] = None,
+        block_alphas: Optional[List[float]] = None,
+        conv_block_dims: Optional[List[int]] = None,
+        conv_block_alphas: Optional[List[float]] = None,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
+        module_class: Type[object] = LoRAModule,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        """
+        LoRA network: すごく引数が多いが、パターンは以下の通り
+        1. lora_dimとalphaを指定
+        2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
+        3. block_dimsとblock_alphasを指定 :  Conv2d3x3には適用しない
+        4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
+        5. modules_dimとmodules_alphaを指定 (推論用)
+        """
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.conv_lora_dim = conv_lora_dim
+        self.conv_alpha = conv_alpha
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        elif block_dims is not None:
+            print(f"create LoRA network from block_dims")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            print(f"block_dims: {block_dims}")
+            print(f"block_alphas: {block_alphas}")
+            if conv_block_dims is not None:
+                print(f"conv_block_dims: {conv_block_dims}")
+                print(f"conv_block_alphas: {conv_block_alphas}")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            if self.conv_lora_dim is not None:
+                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+
+        # create module instances
+        def create_modules(
+            is_unet: bool,
+            text_encoder_idx: Optional[int],  # None, 1, 2
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_UNET
+                if is_unet
+                else (
+                    self.LORA_PREFIX_TEXT_ENCODER
+                    if text_encoder_idx is None
+                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+                )
+            )
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+
+                            if modules_dim is not None:
+                                # per-module dim is specified
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            elif is_unet and block_dims is not None:
+                                # block_dims is specified for the U-Net
+                                block_idx = get_block_index(lora_name)
+                                if is_linear or is_conv2d_1x1:
+                                    dim = block_dims[block_idx]
+                                    alpha = block_alphas[block_idx]
+                                elif conv_block_dims is not None:
+                                    dim = conv_block_dims[block_idx]
+                                    alpha = conv_block_alphas[block_idx]
+                            else:
+                                # default: target all eligible modules
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.conv_lora_dim is not None:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
+
+                            if dim is None or dim == 0:
+                                # record skipped modules for the report below
+                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
+                                    skipped.append(lora_name)
+                                continue
+
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                                rank_dropout=rank_dropout,
+                                module_dropout=module_dropout,
+                            )
+                            loras.append(lora)
+            return loras, skipped
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        # create LoRA for text encoder
+        # creating all modules every time is wasteful; this should be reconsidered
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+                print(f"create LoRA for Text Encoder {index}:")
+            else:
+                index = None
+                print(f"create LoRA for Text Encoder:")
+
+            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+            skipped_te += skipped
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
+            target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3  # avoid mutating the class-level list in place
+
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        skipped = skipped_te + skipped_un
+        if varbose and len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
+            for name in skipped:
+                print(f"\t{name}")
+
+        self.up_lr_weight: List[float] = None
+        self.down_lr_weight: List[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    # returns whether this network can be merged into the base model
+    def is_mergeable(self):
+        return True
+
+    # TODO refactor to common function with apply_to
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        apply_text_encoder = apply_unet = False
+        for key in weights_sd.keys():
+            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                apply_text_encoder = True
+            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                apply_unet = True
+
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(lora.lora_name):
+                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+            lora.merge_to(sd_for_lora, dtype, device)
+
+        print(f"weights are merged")
+
+    # define per-block learning-rate multipliers; the argument order is reversed (up, mid, down), but keep it for now
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: List[float] = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: List[float] = None,
+    ):
+        self.block_lr = True
+        self.down_lr_weight = down_lr_weight
+        self.mid_lr_weight = mid_lr_weight
+        self.up_lr_weight = up_lr_weight
+
+    def get_lr_weight(self, lora: LoRAModule) -> float:
+        lr_weight = 1.0
+        block_idx = get_block_index(lora.lora_name)
+        if block_idx < 0:
+            return lr_weight
+
+        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
+            if self.down_lr_weight != None:
+                lr_weight = self.down_lr_weight[block_idx]
+        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
+            if self.mid_lr_weight != None:
+                lr_weight = self.mid_lr_weight
+        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
+            if self.up_lr_weight != None:
+                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
+
+        return lr_weight
+
+    # it might be better to allow separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(loras: List[LoRAModule]):
+            params = []
+            for lora in loras:
+                # params.extend(lora.parameters())
+                params.extend(lora.get_trainable_params())
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
+
+        if self.unet_loras:
+            if self.block_lr:
+                # group LoRAs by block so that learning rates can be tracked per block
+                block_idx_to_lora = {}
+                for lora in self.unet_loras:
+                    idx = get_block_index(lora.lora_name)
+                    if idx not in block_idx_to_lora:
+                        block_idx_to_lora[idx] = []
+                    block_idx_to_lora[idx].append(lora)
+
+                # set the parameter group for each block
+                for idx, block_loras in block_idx_to_lora.items():
+                    param_data = {"params": enumerate_params(block_loras)}
+
+                    if unet_lr is not None:
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    elif default_lr is not None:
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
+                        continue
+                    all_params.append(param_data)
+
+            else:
+                param_data = {"params": enumerate_params(self.unet_loras)}
+                if unet_lr is not None:
+                    param_data["lr"] = unet_lr
+                all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    # mask is a tensor with values from 0 to 1
+    def set_region(self, sub_prompt_index, is_last_network, mask):
+        if mask.max() == 0:
+            mask = torch.ones_like(mask)
+
+        self.mask = mask
+        self.sub_prompt_index = sub_prompt_index
+        self.is_last_network = is_last_network
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.set_network(self)
+
+    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
+        self.batch_size = batch_size
+        self.num_sub_prompts = num_sub_prompts
+        self.current_size = (height, width)
+        self.shared = shared
+
+        # create masks
+        mask = self.mask
+        mask_dic = {}
+        mask = mask.unsqueeze(0).unsqueeze(1)  # b(1),c(1),h,w
+        ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
+        dtype = ref_weight.dtype
+        device = ref_weight.device
+
+        def resize_add(mh, mw):
+            # print(mh, mw, mh * mw)
+            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
+            m = m.to(device, dtype=dtype)
+            mask_dic[mh * mw] = m
+
+        h = height // 8
+        w = width // 8
+        for _ in range(4):
+            resize_add(h, w)
+            if h % 2 == 1 or w % 2 == 1:  # add extra shape if h/w is not divisible by 2
+                resize_add(h + h % 2, w + w % 2)
+            h = (h + 1) // 2
+            w = (w + 1) // 2
+
+        self.mask_dic = mask_dic
+
+    def backup_weights(self):
+        # back up the original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        # restore the backed-up original weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        # pre-compute the merged weights
+        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            sd = org_module.state_dict()
+
+            org_weight = sd["weight"]
+            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            sd["weight"] = org_weight + lora_weight
+            assert sd["weight"].shape == org_weight.shape
+            org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            lora.enabled = False
+
+    def apply_max_norm_regularization(self, max_norm_value, device):
+        downkeys = []
+        upkeys = []
+        alphakeys = []
+        norms = []
+        keys_scaled = 0
+
+        state_dict = self.state_dict()
+        for key in state_dict.keys():
+            if "lora_down" in key and "weight" in key:
+                downkeys.append(key)
+                upkeys.append(key.replace("lora_down", "lora_up"))
+                alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+        for i in range(len(downkeys)):
+            down = state_dict[downkeys[i]].to(device)
+            up = state_dict[upkeys[i]].to(device)
+            alpha = state_dict[alphakeys[i]].to(device)
+            dim = down.shape[0]
+            scale = alpha / dim
+
+            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+            else:
+                updown = up @ down
+
+            updown *= scale
+
+            norm = updown.norm().clamp(min=max_norm_value / 2)
+            desired = torch.clamp(norm, max=max_norm_value)
+            ratio = desired.cpu() / norm.cpu()
+            sqrt_ratio = ratio**0.5
+            if ratio != 1:
+                keys_scaled += 1
+                state_dict[upkeys[i]] *= sqrt_ratio
+                state_dict[downkeys[i]] *= sqrt_ratio
+            scalednorm = updown.norm() * ratio
+            norms.append(scalednorm.item())
+
+        return keys_scaled, sum(norms) / len(norms), max(norms)
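+
+
+# Usage sketch (hypothetical values): call periodically during training to rescale LoRA pairs
+# whose merged weight norm exceeds max_norm_value:
+#   keys_scaled, mean_norm, max_norm = network.apply_max_norm_regularization(1.0, torch.device("cuda"))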
diff --git a/external/llite/networks/lora_interrogator.py b/external/llite/networks/lora_interrogator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc066fd1525d873e280c970707d24d0061db9e2
--- /dev/null
+++ b/external/llite/networks/lora_interrogator.py
@@ -0,0 +1,139 @@
+
+
+from tqdm import tqdm
+from library import model_util
+import library.train_util as train_util
+import argparse
+from transformers import CLIPTokenizer
+import torch
+
+import lora
+
+TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this model
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def interrogate(args):
+  weights_dtype = torch.float16
+
+  # prepare the models and tokenizer
+  print(f"loading SD model: {args.sd_model}")
+  args.pretrained_model_name_or_path = args.sd_model
+  args.vae = None
+  text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)
+
+  print(f"loading LoRA: {args.model}")
+  network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
+
+  # check whether the weights contain Text Encoder modules; ideally this check should live in the lora module
+  has_te_weight = False
+  for key in weights_sd.keys():
+    if 'lora_te' in key:
+      has_te_weight = True
+      break
+  if not has_te_weight:
+    print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
+    return
+  del vae
+
+  print("loading tokenizer")
+  if args.v2:
+    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
+  else:
+    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+
+  text_encoder.to(DEVICE, dtype=weights_dtype)
+  text_encoder.eval()
+  unet.to(DEVICE, dtype=weights_dtype)
+  unet.eval()  # not strictly necessary since the U-Net is never called
+
+  # go through the tokens one by one
+  token_id_start = 0
+  token_id_end = max(tokenizer.all_special_ids)
+  print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
+
+  def get_all_embeddings(text_encoder):
+    embs = []
+    with torch.no_grad():
+      for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
+        batch = []
+        for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
+          tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
+          # tokens = [tid]  # this variant gives noticeably worse results
+          batch.append(tokens)
+
+        # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu")  # including bos/eos seems to give larger differences [:, 1]
+        # support clip skip
+        batch = torch.tensor(batch).to(DEVICE)
+        if args.clip_skip is None:
+          encoder_hidden_states = text_encoder(batch)[0]
+        else:
+          enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
+          encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
+          encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.to("cpu")
+
+        embs.extend(encoder_hidden_states)
+    return torch.stack(embs)
+
+  print("get original text encoder embeddings.")
+  orig_embs = get_all_embeddings(text_encoder)
+
+  network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
+  info = network.load_state_dict(weights_sd, strict=False)
+  print(f"Loading LoRA weights: {info}")
+
+  network.to(DEVICE, dtype=weights_dtype)
+  network.eval()
+
+  del unet
+
+  print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
+  print("get text encoder embeddings with lora.")
+  lora_embs = get_all_embeddings(text_encoder)
+
+  # compare: for now simply use the mean absolute difference
+  print("comparing...")
+  diffs = {}
+  for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
+    diff = torch.mean(torch.abs(orig_emb - lora_emb))
+    # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1))  # does not detect differences well
+    diff = float(diff.detach().to('cpu').numpy())
+    diffs[token_id_start + i] = diff
+
+  diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
+
+  # show the results
+  print("top 100:")
+  for i, (token, diff) in enumerate(diffs_sorted[:100]):
+    # if diff < 1e-6:
+    #   break
+    string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
+    print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
+
+
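+# Example invocation (file names are placeholders):
+#   python networks/lora_interrogator.py --sd_model sd15.safetensors --model my_lora.safetensors --batch_size 16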
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser()
+
+  parser.add_argument("--v2", action='store_true',
+                      help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
+  parser.add_argument("--sd_model", type=str, default=None,
+                      help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
+  parser.add_argument("--model", type=str, default=None,
+                      help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
+  parser.add_argument("--batch_size", type=int, default=16,
+                      help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
+  parser.add_argument("--clip_skip", type=int, default=None,
+                      help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
+
+  return parser
+
+
+if __name__ == '__main__':
+  parser = setup_parser()
+
+  args = parser.parse_args()
+  interrogate(args)
diff --git a/external/llite/networks/merge_lora.py b/external/llite/networks/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..71492621ef7222444d851d42ea689e2771e218df
--- /dev/null
+++ b/external/llite/networks/merge_lora.py
@@ -0,0 +1,357 @@
+import math
+import argparse
+import os
+import time
+import torch
+from safetensors.torch import load_file, save_file
+from library import sai_model_spec, train_util
+import library.model_util as model_util
+import lora
+
+
+def load_state_dict(file_name, dtype):
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        sd = load_file(file_name)
+        metadata = train_util.load_metadata_from_safetensors(file_name)
+    else:
+        sd = torch.load(file_name, map_location="cpu")
+        metadata = {}
+
+    for key in list(sd.keys()):
+        if type(sd[key]) == torch.Tensor:
+            sd[key] = sd[key].to(dtype)
+
+    return sd, metadata
+
+
+def save_to_file(file_name, model, state_dict, dtype, metadata):
+    if dtype is not None:
+        for key in list(state_dict.keys()):
+            if type(state_dict[key]) == torch.Tensor:
+                state_dict[key] = state_dict[key].to(dtype)
+
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        save_file(model, file_name, metadata=metadata)
+    else:
+        torch.save(model, file_name)
+
+
+def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
+    text_encoder.to(merge_dtype)
+    unet.to(merge_dtype)
+
+    # create module map
+    name_to_module = {}
+    for i, root_module in enumerate([text_encoder, unet]):
+        if i == 0:
+            prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
+            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+        else:
+            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+            target_replace_modules = (
+                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+            )
+
+        for name, module in root_module.named_modules():
+            if module.__class__.__name__ in target_replace_modules:
+                for child_name, child_module in module.named_modules():
+                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                        lora_name = prefix + "." + name + "." + child_name
+                        lora_name = lora_name.replace(".", "_")
+                        name_to_module[lora_name] = child_module
+
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd, _ = load_state_dict(model, merge_dtype)
+
+        print(f"merging...")
+        for key in lora_sd.keys():
+            if "lora_down" in key:
+                up_key = key.replace("lora_down", "lora_up")
+                alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                # find original module for this lora
+                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                if module_name not in name_to_module:
+                    print(f"no module found for LoRA weight: {key}")
+                    continue
+                module = name_to_module[module_name]
+                # print(f"apply {key} to {module}")
+
+                down_weight = lora_sd[key]
+                up_weight = lora_sd[up_key]
+
+                dim = down_weight.size()[0]
+                alpha = lora_sd.get(alpha_key, dim)
+                scale = alpha / dim
+
+                # W <- W + U * D
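+                # up @ down has the shape of the original weight; alpha / dim restores the
+                # scale used during training, and `ratio` is the user-supplied merge strength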
+                weight = module.weight
+                if len(weight.size()) == 2:
+                    # linear
+                    if len(up_weight.size()) == 4:  # LoRA weights saved as 1x1 conv (4D); squeeze to 2D for the Linear layer
+                        up_weight = up_weight.squeeze(3).squeeze(2)
+                        down_weight = down_weight.squeeze(3).squeeze(2)
+                    weight = weight + ratio * (up_weight @ down_weight) * scale
+                elif down_weight.size()[2:4] == (1, 1):
+                    # conv2d 1x1
+                    weight = (
+                        weight
+                        + ratio
+                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                        * scale
+                    )
+                else:
+                    # conv2d 3x3
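+                    # compose the two kernels: convolving lora_down (rank, in, k, k) with the
+                    # 1x1 lora_up (out, rank, 1, 1) yields the full (out, in, k, k) delta weight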
+                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                    # print(conved.size(), weight.size(), module.stride, module.padding)
+                    weight = weight + ratio * conved * scale
+
+                module.weight = torch.nn.Parameter(weight)
+
+
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
+    base_alphas = {}  # alpha for merged model
+    base_dims = {}
+
+    merged_sd = {}
+    v2 = None
+    base_model = None
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
+
+        if lora_metadata is not None:
+            if v2 is None:
+                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string
+            if base_model is None:
+                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
+
+        # get alpha and dim
+        alphas = {}  # alpha for current model
+        dims = {}  # dims for current model
+        for key in lora_sd.keys():
+            if "alpha" in key:
+                lora_module_name = key[: key.rfind(".alpha")]
+                alpha = float(lora_sd[key].detach().numpy())
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+            elif "lora_down" in key:
+                lora_module_name = key[: key.rfind(".lora_down")]
+                dim = lora_sd[key].size()[0]
+                dims[lora_module_name] = dim
+                if lora_module_name not in base_dims:
+                    base_dims[lora_module_name] = dim
+
+        for lora_module_name in dims.keys():
+            if lora_module_name not in alphas:
+                alpha = dims[lora_module_name]
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+
+        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
+
+        # merge
+        print(f"merging...")
+        for key in lora_sd.keys():
+            if "alpha" in key:
+                continue
+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None
+
+            lora_module_name = key[: key.rfind(".lora_")]
+
+            base_alpha = base_alphas[lora_module_name]
+            alpha = alphas[lora_module_name]
+
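+            # each LoRA was trained with an effective scale of alpha / dim; scaling both
+            # lora_up and lora_down by sqrt(alpha / base_alpha) reproduces that scale under
+            # the single base_alpha stored in the merged model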
+            scale = math.sqrt(alpha / base_alpha) * ratio
+            scale = abs(scale) if "lora_up" in key else scale  # support negative ratios: the sign is applied only to lora_down
+
+            if key in merged_sd:
+                assert (
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
+                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+            else:
+                merged_sd[key] = lora_sd[key] * scale
+
+    # set alpha to sd
+    for lora_module_name, alpha in base_alphas.items():
+        key = lora_module_name + ".alpha"
+        merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
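+            # permute the rows of lora_down and the columns of lora_up with the same
+            # permutation, so lora_up @ lora_down is unchanged; only the rank order is shuffled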
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:,perm]
+
+    print("merged model")
+    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
+
+    # check all dims are same
+    dims_list = list(set(base_dims.values()))
+    alphas_list = list(set(base_alphas.values()))
+    all_same_dims = True
+    all_same_alphas = True
+    for dims in dims_list:
+        if dims != dims_list[0]:
+            all_same_dims = False
+            break
+    for alphas in alphas_list:
+        if alphas != alphas_list[0]:
+            all_same_alphas = False
+            break
+
+    # build minimum metadata
+    dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
+    alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
+    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
+
+    return merged_sd, metadata, v2 == "True"
+
+
+def merge(args):
+    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+
+    def str_to_dtype(p):
+        if p == "float":
+            return torch.float
+        if p == "fp16":
+            return torch.float16
+        if p == "bf16":
+            return torch.bfloat16
+        return None
+
+    merge_dtype = str_to_dtype(args.precision)
+    save_dtype = str_to_dtype(args.save_precision)
+    if save_dtype is None:
+        save_dtype = merge_dtype
+
+    if args.sd_model is not None:
+        print(f"loading SD model: {args.sd_model}")
+
+        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
+
+        merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
+
+        if args.no_metadata:
+            sai_metadata = None
+        else:
+            merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
+            title = os.path.splitext(os.path.basename(args.save_to))[0]
+            sai_metadata = sai_model_spec.build_metadata(
+                None,
+                args.v2,
+                args.v2,
+                False,
+                False,
+                False,
+                time.time(),
+                title=title,
+                merged_from=merged_from,
+                is_stable_diffusion_ckpt=True,
+            )
+            if args.v2:
+                # TODO read sai modelspec
+                print(
+                    "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
+                )
+
+        print(f"saving SD model to: {args.save_to}")
+        model_util.save_stable_diffusion_checkpoint(
+            args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
+        )
+    else:
+        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
+
+        print(f"calculating hashes and creating metadata...")
+
+        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+        metadata["sshs_model_hash"] = model_hash
+        metadata["sshs_legacy_hash"] = legacy_hash
+
+        if not args.no_metadata:
+            merged_from = sai_model_spec.build_merged_from(args.models)
+            title = os.path.splitext(os.path.basename(args.save_to))[0]
+            sai_metadata = sai_model_spec.build_metadata(
+                state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
+            )
+            if v2:
+                # TODO read sai modelspec
+                print(
+                    "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
+                )
+            metadata.update(sai_metadata)
+
+        print(f"saving model to: {args.save_to}")
+        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
+    )
+    parser.add_argument(
+        "--sd_model",
+        type=str,
+        default=None,
+        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
+    )
+    parser.add_argument(
+        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+    )
+    parser.add_argument(
+        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
+    )
+    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
+    parser.add_argument(
+        "--no_metadata",
+        action="store_true",
+        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
+    )
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight./ "
+        + "LoRAの重みをシャッフルする",
+    )
+    
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    merge(args)
diff --git a/external/llite/networks/merge_lora_old.py b/external/llite/networks/merge_lora_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffd6b2b4095ec68a419e925908fe82ff69c5afad
--- /dev/null
+++ b/external/llite/networks/merge_lora_old.py
@@ -0,0 +1,185 @@
+
+
+import argparse
+import os
+import torch
+from safetensors.torch import load_file, save_file
+import library.model_util as model_util
+import lora
+
+
+def load_state_dict(file_name, dtype):
+  if os.path.splitext(file_name)[1] == '.safetensors':
+    sd = load_file(file_name)
+  else:
+    sd = torch.load(file_name, map_location='cpu')
+  for key in list(sd.keys()):
+    if type(sd[key]) == torch.Tensor:
+      sd[key] = sd[key].to(dtype)
+  return sd
+
+
+def save_to_file(file_name, model, state_dict, dtype):
+  if dtype is not None:
+    for key in list(state_dict.keys()):
+      if type(state_dict[key]) == torch.Tensor:
+        state_dict[key] = state_dict[key].to(dtype)
+
+  if os.path.splitext(file_name)[1] == '.safetensors':
+    save_file(model, file_name)
+  else:
+    torch.save(model, file_name)
+
+
+def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
+  text_encoder.to(merge_dtype)
+  unet.to(merge_dtype)
+
+  # create module map
+  name_to_module = {}
+  for i, root_module in enumerate([text_encoder, unet]):
+    if i == 0:
+      prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
+      target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+    else:
+      prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+      target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
+
+    for name, module in root_module.named_modules():
+      if module.__class__.__name__ in target_replace_modules:
+        for child_name, child_module in module.named_modules():
+          if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
+            lora_name = prefix + '.' + name + '.' + child_name
+            lora_name = lora_name.replace('.', '_')
+            name_to_module[lora_name] = child_module
+
+  for model, ratio in zip(models, ratios):
+    print(f"loading: {model}")
+    lora_sd = load_state_dict(model, merge_dtype)
+
+    print(f"merging...")
+    for key in lora_sd.keys():
+      if "lora_down" in key:
+        up_key = key.replace("lora_down", "lora_up")
+        alpha_key = key[:key.index("lora_down")] + 'alpha'
+
+        # find original module for this lora
+        module_name = '.'.join(key.split('.')[:-2])               # remove trailing ".lora_down.weight"
+        if module_name not in name_to_module:
+          print(f"no module found for LoRA weight: {key}")
+          continue
+        module = name_to_module[module_name]
+        # print(f"apply {key} to {module}")
+
+        down_weight = lora_sd[key]
+        up_weight = lora_sd[up_key]
+
+        dim = down_weight.size()[0]
+        alpha = lora_sd.get(alpha_key, dim)
+        scale = alpha / dim
+
+        # W <- W + U * D
+        weight = module.weight
+        if len(weight.size()) == 2:
+          # linear
+          weight = weight + ratio * (up_weight @ down_weight) * scale
+        else:
+          # conv2d
+          weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
+
+        module.weight = torch.nn.Parameter(weight)
+
+
+def merge_lora_models(models, ratios, merge_dtype):
+  merged_sd = {}
+
+  alpha = None
+  dim = None
+  for model, ratio in zip(models, ratios):
+    print(f"loading: {model}")
+    lora_sd = load_state_dict(model, merge_dtype)
+
+    print(f"merging...")
+    for key in lora_sd.keys():
+      if 'alpha' in key:
+        if key in merged_sd:
+          assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
+        else:
+          alpha = lora_sd[key].detach().numpy()
+          merged_sd[key] = lora_sd[key]
+      else:
+        if key in merged_sd:
+          assert merged_sd[key].size() == lora_sd[key].size(
+          ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
+          merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
+        else:
+          if "lora_down" in key:
+            dim = lora_sd[key].size()[0]
+          merged_sd[key] = lora_sd[key] * ratio
+
+  print(f"dim (rank): {dim}, alpha: {alpha}")
+  if alpha is None:
+    alpha = dim
+
+  return merged_sd, dim, alpha
+
+
+def merge(args):
+  assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+
+  def str_to_dtype(p):
+    if p == 'float':
+      return torch.float
+    if p == 'fp16':
+      return torch.float16
+    if p == 'bf16':
+      return torch.bfloat16
+    return None
+
+  merge_dtype = str_to_dtype(args.precision)
+  save_dtype = str_to_dtype(args.save_precision)
+  if save_dtype is None:
+    save_dtype = merge_dtype
+
+  if args.sd_model is not None:
+    print(f"loading SD model: {args.sd_model}")
+
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
+
+    merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
+
+    print(f"\nsaving SD model to: {args.save_to}")
+    model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
+                                                args.sd_model, 0, 0, save_dtype, vae)
+  else:
+    state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
+
+    print(f"\nsaving model to: {args.save_to}")
+    save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--v2", action='store_true',
+                      help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
+  parser.add_argument("--save_precision", type=str, default=None,
+                      choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
+  parser.add_argument("--precision", type=str, default="float",
+                      choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
+  parser.add_argument("--sd_model", type=str, default=None,
+                      help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
+  parser.add_argument("--save_to", type=str, default=None,
+                      help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
+  parser.add_argument("--models", type=str, nargs='*',
+                      help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
+  parser.add_argument("--ratios", type=float, nargs='*',
+                      help="ratios for each model / それぞれのLoRAモデルの比率")
+
+  return parser
+
+
+if __name__ == '__main__':
+  parser = setup_parser()
+
+  args = parser.parse_args()
+  merge(args)
diff --git a/external/llite/networks/oft.py b/external/llite/networks/oft.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d088f8779ca857ee65d831f428442a13cd350df
--- /dev/null
+++ b/external/llite/networks/oft.py
@@ -0,0 +1,430 @@
+# OFT network module
+
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import re
+
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+class OFTModule(torch.nn.Module):
+    """
+    replaces forward method of the original Linear, instead of replacing the original Linear module.
+    """
+
+    def __init__(
+        self,
+        oft_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        dim=4,
+        alpha=1,
+    ):
+        """
+        dim -> num blocks
+        alpha -> constraint
+        """
+        super().__init__()
+        self.oft_name = oft_name
+
+        self.num_blocks = dim
+
+        if "Linear" in org_module.__class__.__name__:
+            out_dim = org_module.out_features
+        elif "Conv" in org_module.__class__.__name__:
+            out_dim = org_module.out_channels
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().numpy()
+        self.constraint = alpha * out_dim
+        self.register_buffer("alpha", torch.tensor(alpha))
+
+        self.block_size = out_dim // self.num_blocks
+        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
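+        # one trainable (block_size x block_size) block per group of output channels,
+        # initialized to zero so the initial rotation R is the identity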
+
+        self.out_dim = out_dim
+        self.shape = org_module.weight.shape
+
+        self.multiplier = multiplier
+        self.org_module = [org_module]  # keep in a list so it is not registered as a submodule
+
+    def apply_to(self):
+        self.org_forward = self.org_module[0].forward
+        self.org_module[0].forward = self.forward
+
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
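+        # build a skew-symmetric Q per block, rescale it so its norm stays within the
+        # constraint, then map it to an orthogonal block via the Cayley transform
+        # R_block = (I + Q)(I - Q)^-1 and assemble a block-diagonal R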
+        block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
+        norm_Q = torch.norm(block_Q.flatten())
+        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
+        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
+
+        block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
+        R = torch.block_diag(*block_R_weighted)
+
+        return R
+
+    def forward(self, x, scale=None):
+        x = self.org_forward(x)
+        if self.multiplier == 0.0:
+            return x
+
+        R = self.get_weight().to(x.device, dtype=x.dtype)
+        if x.dim() == 4:
+            x = x.permute(0, 2, 3, 1)
+            x = torch.matmul(x, R)
+            x = x.permute(0, 3, 1, 2)
+        else:
+            x = torch.matmul(x, R)
+        return x
+
+
+class OFTInfModule(OFTModule):
+    def __init__(
+        self,
+        oft_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(oft_name, org_module, multiplier, dim, alpha)
+        self.enabled = True
+        self.network: OFTNetwork = None
+
+    def set_network(self, network):
+        self.network = network
+
+    def forward(self, x, scale=None):
+        if not self.enabled:
+            return self.org_forward(x)
+        return super().forward(x, scale)
+
+    def merge_to(self, multiplier=None, sign=1):
+        R = self.get_weight(multiplier) * sign
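+        # R is block-diagonal orthogonal; the einsum below rotates the output channels
+        # of the original weight: new_W[p, ...] = sum_o W[o, ...] * R[o, p]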
+
+        # get org weight
+        org_sd = self.org_module[0].state_dict()
+        org_weight = org_sd["weight"]
+        R = R.to(org_weight.device, dtype=org_weight.dtype)
+
+        if org_weight.dim() == 4:
+            weight = torch.einsum("oihw, op -> pihw", org_weight, R)
+        else:
+            weight = torch.einsum("oi, op -> pi", org_weight, R)
+
+        # set weight to org_module
+        org_sd["weight"] = weight
+        self.org_module[0].load_state_dict(org_sd)
+
+
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    neuron_dropout: Optional[float] = None,
+    **kwargs,
+):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    enable_all_linear = kwargs.get("enable_all_linear", None)
+    enable_conv = kwargs.get("enable_conv", None)
+    if enable_all_linear is not None:
+        enable_all_linear = bool(enable_all_linear)
+    if enable_conv is not None:
+        enable_conv = bool(enable_conv)
+
+    network = OFTNetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        dim=network_dim,
+        alpha=network_alpha,
+        enable_all_linear=enable_all_linear,
+        enable_conv=enable_conv,
+        varbose=True,
+    )
+    return network
+
+
+# Create network from weights for inference; weights are not loaded here (because they may be merged instead)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file, safe_open
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # check dim, alpha, whether the weights include conv2d, and whether all linear layers are targeted
+    dim = None
+    alpha = None
+    has_conv2d = None
+    all_linear = None
+    for name, param in weights_sd.items():
+        if name.endswith(".alpha"):
+            if alpha is None:
+                alpha = param.item()
+        else:
+            if dim is None:
+                dim = param.size()[0]
+            if has_conv2d is None and param.dim() == 4:
+                has_conv2d = True
+            if all_linear is None:
+                if param.dim() == 3 and "attn" not in name:
+                    all_linear = True
+        if dim is not None and alpha is not None and has_conv2d is not None:
+            break
+    if has_conv2d is None:
+        has_conv2d = False
+    if all_linear is None:
+        all_linear = False
+
+    module_class = OFTInfModule if for_inference else OFTModule
+    network = OFTNetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        dim=dim,
+        alpha=alpha,
+        enable_all_linear=all_linear,
+        enable_conv=has_conv2d,
+        module_class=module_class,
+    )
+    return network, weights_sd
+
+
+class OFTNetwork(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
+    UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    OFT_PREFIX_UNET = "oft_unet"  # probably better not to change this
+
+    def __init__(
+        self,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+        unet,
+        multiplier: float = 1.0,
+        dim: int = 4,
+        alpha: float = 1,
+        enable_all_linear: Optional[bool] = False,
+        enable_conv: Optional[bool] = False,
+        module_class: Type[object] = OFTModule,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.dim = dim
+        self.alpha = alpha
+
+        print(
+            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
+        )
+
+        # create module instances
+        def create_modules(
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[OFTModule]:
+            prefix = self.OFT_PREFIX_UNET
+            ofts = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = "Linear" in child_module.__class__.__name__
+                        is_conv2d = "Conv2d" in child_module.__class__.__name__
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
+                            oft_name = prefix + "." + name + "." + child_name
+                            oft_name = oft_name.replace(".", "_")
+                            # print(oft_name)
+
+                            oft = module_class(
+                                oft_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                            )
+                            ofts.append(oft)
+            return ofts
+
+        # choose U-Net target modules: all linear layers or attention only; add conv2d 3x3 targets if enabled
+        if enable_all_linear:
+            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
+        else:
+            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
+        if enable_conv:
+            target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
+        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
+
+        # assertion
+        names = set()
+        for oft in self.unet_ofts:
+            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
+            names.add(oft.oft_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for oft in self.unet_ofts:
+            oft.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        assert apply_unet, "apply_unet must be True"
+
+        for oft in self.unet_ofts:
+            oft.apply_to()
+            self.add_module(oft.oft_name, oft)
+
+    # returns whether the weights can be merged into the base model
+    def is_mergeable(self):
+        return True
+
+    # TODO refactor to common function with apply_to
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        print("enable OFT for U-Net")
+
+        for oft in self.unet_ofts:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(oft.oft_name):
+                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
+            oft.load_state_dict(sd_for_lora, False)
+            oft.merge_to()
+
+        print(f"weights are merged")
+
+    # it might be good to allow separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(ofts):
+            params = []
+            for oft in ofts:
+                params.extend(oft.parameters())
+
+            # print num of params
+            num_params = 0
+            for p in params:
+                num_params += p.numel()
+            print(f"OFT params: {num_params}")
+            return params
+
+        param_data = {"params": enumerate_params(self.unet_ofts)}
+        if unet_lr is not None:
+            param_data["lr"] = unet_lr
+        all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    def backup_weights(self):
+        # back up the original weights
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        # restore the original weights
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        # pre-calculate: merge the OFT weights into the original modules
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            oft.merge_to()
+            # sd = org_module.state_dict()
+            # org_weight = sd["weight"]
+            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            # sd["weight"] = org_weight + lora_weight
+            # assert sd["weight"].shape == org_weight.shape
+            # org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            oft.enabled = False
diff --git a/external/llite/networks/resize_lora.py b/external/llite/networks/resize_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..03fc545e7d8f0612b386057b5efae1c69173225b
--- /dev/null
+++ b/external/llite/networks/resize_lora.py
@@ -0,0 +1,362 @@
+# Convert LoRA to different rank approximation (should only be used to go to lower rank)
+# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
+# Thanks to cloneofsimo
+
+import argparse
+import torch
+from safetensors.torch import load_file, save_file, safe_open
+from tqdm import tqdm
+from library import train_util, model_util
+import numpy as np
+
+MIN_SV = 1e-6
+
+# Model save and load functions
+
+def load_state_dict(file_name, dtype):
+  if model_util.is_safetensors(file_name):
+    sd = load_file(file_name)
+    with safe_open(file_name, framework="pt") as f:
+      metadata = f.metadata()
+  else:
+    sd = torch.load(file_name, map_location='cpu')
+    metadata = None
+
+  for key in list(sd.keys()):
+    if type(sd[key]) == torch.Tensor:
+      sd[key] = sd[key].to(dtype)
+
+  return sd, metadata
+
+
+def save_to_file(file_name, model, state_dict, dtype, metadata):
+  if dtype is not None:
+    for key in list(state_dict.keys()):
+      if type(state_dict[key]) == torch.Tensor:
+        state_dict[key] = state_dict[key].to(dtype)
+
+  if model_util.is_safetensors(file_name):
+    save_file(model, file_name, metadata)
+  else:
+    torch.save(model, file_name)
+
+
+# Indexing functions
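+# Each function returns how many leading singular values to keep:
+#   index_sv_cumulative: enough values to cover `target` of the total sum of singular values
+#   index_sv_fro:        enough values to cover `target` of the Frobenius norm
+#   index_sv_ratio:      all values no smaller than max(S) / target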
+
+def index_sv_cumulative(S, target):
+  original_sum = float(torch.sum(S))
+  cumulative_sums = torch.cumsum(S, dim=0)/original_sum
+  index = int(torch.searchsorted(cumulative_sums, target)) + 1
+  index = max(1, min(index, len(S)-1))
+
+  return index
+
+
+def index_sv_fro(S, target):
+  S_squared = S.pow(2)
+  s_fro_sq = float(torch.sum(S_squared))
+  sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
+  index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
+  index = max(1, min(index, len(S)-1))
+
+  return index
+
+
+def index_sv_ratio(S, target):
+  max_sv = S[0]
+  min_sv = max_sv/target
+  index = int(torch.sum(S > min_sv).item())
+  index = max(1, min(index, len(S)-1))
+
+  return index
+
+
+# Modified from Kohaku-blueleaf's extract/merge functions
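+# extract_*: SVD the (reshaped) weight, keep the top `new_rank` singular values and fold
+# S into U, so that lora_up @ lora_down approximates the weight at the reduced rank.
+# merge_*: rebuild the full delta weight from an existing lora_down / lora_up pair.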
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+    out_size, in_size, kernel_size, _ = weight.size()
+    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+    
+    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+    lora_rank = param_dict["new_rank"]
+
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    del U, S, Vh, weight
+    return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+    out_size, in_size = weight.size()
+    
+    U, S, Vh = torch.linalg.svd(weight.to(device))
+    
+    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+    lora_rank = param_dict["new_rank"]
+    
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+    
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+    del U, S, Vh, weight
+    return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+    in_rank, in_size, kernel_size, k_ = lora_down.shape
+    out_size, out_rank, _, _ = lora_up.shape
+    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+    
+    lora_down = lora_down.to(device)
+    lora_up = lora_up.to(device)
+
+    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+    del lora_up, lora_down
+    return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+    in_rank, in_size = lora_down.shape
+    out_size, out_rank = lora_up.shape
+    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+    
+    lora_down = lora_down.to(device)
+    lora_up = lora_up.to(device)
+    
+    weight = lora_up @ lora_down
+    del lora_up, lora_down
+    return weight
+  
+
+# Calculate new rank
+
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+    param_dict = {}
+
+    if dynamic_method=="sv_ratio":
+        # Calculate new dim and alpha based off ratio
+        new_rank = index_sv_ratio(S, dynamic_param) + 1
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_cumulative":
+        # Calculate new dim and alpha based off cumulative sum
+        new_rank = index_sv_cumulative(S, dynamic_param) + 1
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_fro":
+        # Calculate new dim and alpha based off sqrt sum of squares
+        new_rank = index_sv_fro(S, dynamic_param) + 1
+        new_alpha = float(scale*new_rank)
+    else:
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+    
+    if S[0] <= MIN_SV: # Zero matrix, set dim to 1
+        new_rank = 1
+        new_alpha = float(scale*new_rank)
+    elif new_rank > rank: # cap max rank at rank
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+
+    # Calculate resize info
+    s_sum = torch.sum(torch.abs(S))
+    s_rank = torch.sum(torch.abs(S[:new_rank]))
+    
+    S_squared = S.pow(2)
+    s_fro = torch.sqrt(torch.sum(S_squared))
+    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+    fro_percent = float(s_red_fro/s_fro)
+
+    param_dict["new_rank"] = new_rank
+    param_dict["new_alpha"] = new_alpha
+    param_dict["sum_retained"] = (s_rank)/s_sum
+    param_dict["fro_retained"] = fro_percent
+    param_dict["max_ratio"] = S[0]/S[new_rank - 1]
+
+    return param_dict
+
+
+def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
+  network_alpha = None
+  network_dim = None
+  verbose_str = "\n"
+  fro_list = []
+
+  # Extract loaded lora dim and alpha
+  for key, value in lora_sd.items():
+    if network_alpha is None and 'alpha' in key:
+      network_alpha = value
+    if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
+      network_dim = value.size()[0]
+    if network_alpha is not None and network_dim is not None:
+      break
+
+  # fall back to dim if the state dict has no alpha
+  if network_alpha is None:
+    network_alpha = network_dim
+
+  scale = network_alpha/network_dim
+
+  if dynamic_method:
+    print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
+
+  lora_down_weight = None
+  lora_up_weight = None
+
+  o_lora_sd = lora_sd.copy()
+  block_down_name = None
+  block_up_name = None
+
+  with torch.no_grad():
+    for key, value in tqdm(lora_sd.items()):
+      weight_name = None
+      if 'lora_down' in key:
+        block_down_name = key.rsplit('.lora_down', 1)[0]
+        weight_name = key.rsplit(".", 1)[-1]
+        lora_down_weight = value
+      else:
+        continue
+
+      # find corresponding lora_up and alpha
+      block_up_name = block_down_name
+      lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
+      lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
+
+      weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
+
+      if weights_loaded:
+
+        conv2d = (len(lora_down_weight.size()) == 4)
+        if lora_alpha is None:
+          scale = 1.0
+        else:
+          scale = lora_alpha/lora_down_weight.size()[0]
+
+        if conv2d:
+          full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+        else:
+          full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+          param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+
+        if verbose:
+          max_ratio = param_dict['max_ratio']
+          sum_retained = param_dict['sum_retained']
+          fro_retained = param_dict['fro_retained']
+          if not np.isnan(fro_retained):
+            fro_list.append(float(fro_retained))
+
+          verbose_str+=f"{block_down_name:75} | "
+          verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
+
+        if verbose and dynamic_method:
+          verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
+        else:
+          verbose_str+=f"\n"
+
+        new_alpha = param_dict['new_alpha']
+        o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
+
+        block_down_name = None
+        block_up_name = None
+        lora_down_weight = None
+        lora_up_weight = None
+        weights_loaded = False
+        del param_dict
+
+  if verbose:
+    print(verbose_str)
+
+    print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
+  print("resizing complete")
+  return o_lora_sd, network_dim, new_alpha
+
+
+def resize(args):
+  if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
+    raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
+
+    
+  def str_to_dtype(p):
+    if p == 'float':
+      return torch.float
+    if p == 'fp16':
+      return torch.float16
+    if p == 'bf16':
+      return torch.bfloat16
+    return None
+
+  if args.dynamic_method and not args.dynamic_param:
+    raise Exception("If using dynamic_method, then dynamic_param is required")
+
+  merge_dtype = str_to_dtype('float')  # matmul method above only seems to work in float32
+  save_dtype = str_to_dtype(args.save_precision)
+  if save_dtype is None:
+    save_dtype = merge_dtype
+
+  print("loading Model...")
+  lora_sd, metadata = load_state_dict(args.model, merge_dtype)
+
+  print("Resizing Lora...")
+  state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
+
+  # update metadata
+  if metadata is None:
+    metadata = {}
+
+  comment = metadata.get("ss_training_comment", "")
+
+  if not args.dynamic_method:
+    metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
+    metadata["ss_network_dim"] = str(args.new_rank)
+    metadata["ss_network_alpha"] = str(new_alpha)
+  else:
+    metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
+    metadata["ss_network_dim"] = 'Dynamic'
+    metadata["ss_network_alpha"] = 'Dynamic'
+
+  model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+  metadata["sshs_model_hash"] = model_hash
+  metadata["sshs_legacy_hash"] = legacy_hash
+
+  print(f"saving model to: {args.save_to}")
+  save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser()
+
+  parser.add_argument("--save_precision", type=str, default=None,
+                      choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
+  parser.add_argument("--new_rank", type=int, default=4,
+                      help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
+  parser.add_argument("--save_to", type=str, default=None,
+                      help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
+  parser.add_argument("--model", type=str, default=None,
+                      help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
+  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+  parser.add_argument("--verbose", action="store_true", 
+                      help="Display verbose resizing information / rank変更時の詳細情報を出力する")
+  parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
+                      help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
+  parser.add_argument("--dynamic_param", type=float, default=None,
+                      help="Specify target for dynamic reduction")
+       
+  return parser
+
+
+if __name__ == '__main__':
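+  # Example invocation (assumes the script is run so that `library` is importable;
+  # file names are placeholders):
+  #   python networks/resize_lora.py --model lora.safetensors --save_to lora_rank8.safetensors \
+  #     --new_rank 8 --device cuda --dynamic_method sv_fro --dynamic_param 0.9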
+  parser = setup_parser()
+
+  args = parser.parse_args()
+  resize(args)
diff --git a/external/llite/networks/sdxl_merge_lora.py b/external/llite/networks/sdxl_merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..c513eb59f585bd6a335f8e3eea32efc827300d2b
--- /dev/null
+++ b/external/llite/networks/sdxl_merge_lora.py
@@ -0,0 +1,348 @@
+import math
+import argparse
+import os
+import time
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from library import sai_model_spec, sdxl_model_util, train_util
+import library.model_util as model_util
+import lora
+
+
+def load_state_dict(file_name, dtype):
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        sd = load_file(file_name)
+        metadata = train_util.load_metadata_from_safetensors(file_name)
+    else:
+        sd = torch.load(file_name, map_location="cpu")
+        metadata = {}
+
+    for key in list(sd.keys()):
+        if type(sd[key]) == torch.Tensor:
+            sd[key] = sd[key].to(dtype)
+
+    return sd, metadata
+
+
+def save_to_file(file_name, model, state_dict, dtype, metadata):
+    if dtype is not None:
+        for key in list(state_dict.keys()):
+            if type(state_dict[key]) == torch.Tensor:
+                state_dict[key] = state_dict[key].to(dtype)
+
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        save_file(model, file_name, metadata=metadata)
+    else:
+        torch.save(model, file_name)
+
+
+def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
+    text_encoder1.to(merge_dtype)
+    text_encoder2.to(merge_dtype)
+    unet.to(merge_dtype)
+
+    # create module map
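+    # SDXL has two text encoders, so LoRA keys use three prefixes: one per text encoder and one for the U-Net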
+    name_to_module = {}
+    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
+        if i <= 1:
+            if i == 0:
+                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
+            else:
+                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
+            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+        else:
+            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+            target_replace_modules = (
+                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+            )
+
+        for name, module in root_module.named_modules():
+            if module.__class__.__name__ in target_replace_modules:
+                for child_name, child_module in module.named_modules():
+                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                        lora_name = prefix + "." + name + "." + child_name
+                        lora_name = lora_name.replace(".", "_")
+                        name_to_module[lora_name] = child_module
+
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd, _ = load_state_dict(model, merge_dtype)
+
+        print(f"merging...")
+        for key in tqdm(lora_sd.keys()):
+            if "lora_down" in key:
+                up_key = key.replace("lora_down", "lora_up")
+                alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                # find original module for this lora
+                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                if module_name not in name_to_module:
+                    print(f"no module found for LoRA weight: {key}")
+                    continue
+                module = name_to_module[module_name]
+                # print(f"apply {key} to {module}")
+
+                down_weight = lora_sd[key]
+                up_weight = lora_sd[up_key]
+
+                dim = down_weight.size()[0]
+                alpha = lora_sd.get(alpha_key, dim)
+                scale = alpha / dim
+
+                # W <- W + U * D
+                weight = module.weight
+                # print(module_name, down_weight.size(), up_weight.size())
+                if len(weight.size()) == 2:
+                    # linear
+                    weight = weight + ratio * (up_weight @ down_weight) * scale
+                elif down_weight.size()[2:4] == (1, 1):
+                    # conv2d 1x1
+                    weight = (
+                        weight
+                        + ratio
+                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                        * scale
+                    )
+                else:
+                    # conv2d 3x3
+                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                    # print(conved.size(), weight.size(), module.stride, module.padding)
+                    weight = weight + ratio * conved * scale
+
+                module.weight = torch.nn.Parameter(weight)
+
+
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
+    base_alphas = {}  # alpha for merged model
+    base_dims = {}
+
+    merged_sd = {}
+    v2 = None
+    base_model = None
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
+
+        if lora_metadata is not None:
+            if v2 is None:
+                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string; SDXL has no v2, so this should be False
+            if base_model is None:
+                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
+
+        # get alpha and dim
+        alphas = {}  # alpha for current model
+        dims = {}  # dims for current model
+        for key in lora_sd.keys():
+            if "alpha" in key:
+                lora_module_name = key[: key.rfind(".alpha")]
+                alpha = float(lora_sd[key].detach().numpy())
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+            elif "lora_down" in key:
+                lora_module_name = key[: key.rfind(".lora_down")]
+                dim = lora_sd[key].size()[0]
+                dims[lora_module_name] = dim
+                if lora_module_name not in base_dims:
+                    base_dims[lora_module_name] = dim
+
+        for lora_module_name in dims.keys():
+            if lora_module_name not in alphas:
+                alpha = dims[lora_module_name]
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+
+        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
+
+        # merge
+        print(f"merging...")
+        for key in tqdm(lora_sd.keys()):
+            if "alpha" in key:
+                continue
+            
+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None
+
+            lora_module_name = key[: key.rfind(".lora_")]
+
+            base_alpha = base_alphas[lora_module_name]
+            alpha = alphas[lora_module_name]
+
+            scale = math.sqrt(alpha / base_alpha) * ratio
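+            # the sqrt spreads the alpha rescaling over lora_down and lora_up so their product carries alpha / base_alpha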
+            scale = abs(scale) if "lora_up" in key else scale  # support negative weights: keep the sign only on lora_down so a negative ratio flips the delta exactly once
+            
+            if key in merged_sd:
+                assert (
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
+                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+            else:
+                merged_sd[key] = lora_sd[key] * scale
+
+    # set alpha to sd
+    for lora_module_name, alpha in base_alphas.items():
+        key = lora_module_name + ".alpha"
+        merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:,perm]
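+            # permuting the rank axis of both factors consistently leaves the product lora_up @ lora_down unchanged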
+
+    print("merged model")
+    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
+
+    # check all dims are same
+    dims_list = list(set(base_dims.values()))
+    alphas_list = list(set(base_alphas.values()))
+    all_same_dims = True
+    all_same_alphas = True
+    for dims in dims_list:
+        if dims != dims_list[0]:
+            all_same_dims = False
+            break
+    for alphas in alphas_list:
+        if alphas != alphas_list[0]:
+            all_same_alphas = False
+            break
+
+    # build minimum metadata
+    dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
+    alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
+    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
+
+    return merged_sd, metadata
+
+
+def merge(args):
+    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+
+    def str_to_dtype(p):
+        if p == "float":
+            return torch.float
+        if p == "fp16":
+            return torch.float16
+        if p == "bf16":
+            return torch.bfloat16
+        return None
+
+    merge_dtype = str_to_dtype(args.precision)
+    save_dtype = str_to_dtype(args.save_precision)
+    if save_dtype is None:
+        save_dtype = merge_dtype
+
+    if args.sd_model is not None:
+        print(f"loading SD model: {args.sd_model}")
+
+        (
+            text_model1,
+            text_model2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
+
+        merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
+
+        if args.no_metadata:
+            sai_metadata = None
+        else:
+            merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
+            title = os.path.splitext(os.path.basename(args.save_to))[0]
+            sai_metadata = sai_model_spec.build_metadata(
+                None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
+            )
+
+        print(f"saving SD model to: {args.save_to}")
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
+        )
+    else:
+        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
+
+        print(f"calculating hashes and creating metadata...")
+
+        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+        metadata["sshs_model_hash"] = model_hash
+        metadata["sshs_legacy_hash"] = legacy_hash
+
+        if not args.no_metadata:
+            merged_from = sai_model_spec.build_merged_from(args.models)
+            title = os.path.splitext(os.path.basename(args.save_to))[0]
+            sai_metadata = sai_model_spec.build_metadata(
+                state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
+            )
+            metadata.update(sai_metadata)
+
+        print(f"saving model to: {args.save_to}")
+        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
+    )
+    parser.add_argument(
+        "--sd_model",
+        type=str,
+        default=None,
+        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
+    )
+    parser.add_argument(
+        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+    )
+    parser.add_argument(
+        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
+    )
+    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
+    parser.add_argument(
+        "--no_metadata",
+        action="store_true",
+        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
+    )
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight./ "
+        + "LoRAの重みをシャッフルする",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    merge(args)
diff --git a/external/llite/networks/svd_merge_lora.py b/external/llite/networks/svd_merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e813b36eb433da71768925add652f11108074d
--- /dev/null
+++ b/external/llite/networks/svd_merge_lora.py
@@ -0,0 +1,260 @@
+import math
+import argparse
+import os
+import time
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from library import sai_model_spec, train_util
+import library.model_util as model_util
+import lora
+
+
+CLAMP_QUANTILE = 0.99
+
+
+def load_state_dict(file_name, dtype):
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        sd = load_file(file_name)
+        metadata = train_util.load_metadata_from_safetensors(file_name)
+    else:
+        sd = torch.load(file_name, map_location="cpu")
+        metadata = {}
+
+    for key in list(sd.keys()):
+        if type(sd[key]) == torch.Tensor:
+            sd[key] = sd[key].to(dtype)
+
+    return sd, metadata
+
+
+def save_to_file(file_name, state_dict, dtype, metadata):
+    if dtype is not None:
+        for key in list(state_dict.keys()):
+            if type(state_dict[key]) == torch.Tensor:
+                state_dict[key] = state_dict[key].to(dtype)
+
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        save_file(state_dict, file_name, metadata=metadata)
+    else:
+        torch.save(state_dict, file_name)
+
+
+def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
+    print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
+    merged_sd = {}
+    v2 = None
+    base_model = None
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
+
+        if lora_metadata is not None:
+            if v2 is None:
+                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # return string
+            if base_model is None:
+                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
+
+        # merge
+        print(f"merging...")
+        for key in tqdm(list(lora_sd.keys())):
+            if "lora_down" not in key:
+                continue
+
+            lora_module_name = key[: key.rfind(".lora_down")]
+
+            down_weight = lora_sd[key]
+            network_dim = down_weight.size()[0]
+
+            up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
+            alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
+
+            in_dim = down_weight.size()[1]
+            out_dim = up_weight.size()[0]
+            conv2d = len(down_weight.size()) == 4
+            kernel_size = None if not conv2d else down_weight.size()[2:4]
+            # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
+
+            # make original weight if not exist
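+            # the accumulated weight starts at zero: only the summed LoRA deltas are stored here and re-decomposed by SVD below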
+            if lora_module_name not in merged_sd:
+                weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
+                if device:
+                    weight = weight.to(device)
+            else:
+                weight = merged_sd[lora_module_name]
+
+            # merge to weight
+            if device:
+                up_weight = up_weight.to(device)
+                down_weight = down_weight.to(device)
+
+            # W <- W + U * D
+            scale = alpha / network_dim
+
+            if device and isinstance(scale, torch.Tensor):  # scale is a python float when the alpha key is missing; only tensors need moving
+                scale = scale.to(device)
+
+            if not conv2d:  # linear
+                weight = weight + ratio * (up_weight @ down_weight) * scale
+            elif kernel_size == (1, 1):
+                weight = (
+                    weight
+                    + ratio
+                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                    * scale
+                )
+            else:
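+                # conv2d 3x3: compose the two conv factors into a single (out_dim, in_dim, kh, kw) weight delta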
+                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                weight = weight + ratio * conved * scale
+
+            merged_sd[lora_module_name] = weight
+
+    # extract from merged weights
+    print("extract new lora...")
+    merged_lora_sd = {}
+    with torch.no_grad():
+        for lora_module_name, mat in tqdm(list(merged_sd.items())):
+            conv2d = len(mat.size()) == 4
+            kernel_size = None if not conv2d else mat.size()[2:4]
+            conv2d_3x3 = conv2d and kernel_size != (1, 1)
+            out_dim, in_dim = mat.size()[0:2]
+
+            if conv2d:
+                if conv2d_3x3:
+                    mat = mat.flatten(start_dim=1)
+                else:
+                    mat = mat.squeeze()
+
+            module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+            module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
+
+            U, S, Vh = torch.linalg.svd(mat)
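+            # keeping only the top module_new_rank singular values gives the best rank-r approximation (Eckart-Young)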
+
+            U = U[:, :module_new_rank]
+            S = S[:module_new_rank]
+            U = U @ torch.diag(S)
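+            # fold the singular values into U so that U @ Vh reproduces the truncated matrix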
+
+            Vh = Vh[:module_new_rank, :]
+
+            dist = torch.cat([U.flatten(), Vh.flatten()])
+            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+            low_val = -hi_val
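+            # clamp both factors symmetrically at the CLAMP_QUANTILE (99th percentile) magnitude to suppress outliers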
+
+            U = U.clamp(low_val, hi_val)
+            Vh = Vh.clamp(low_val, hi_val)
+
+            if conv2d:
+                U = U.reshape(out_dim, module_new_rank, 1, 1)
+                Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
+
+            up_weight = U
+            down_weight = Vh
+
+            merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
+            merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
+            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
+
+    # build minimum metadata
+    dims = f"{new_rank}"
+    alphas = f"{new_rank}"
+    if new_conv_rank is not None:
+        network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
+    else:
+        network_args = None
+    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
+
+    return merged_lora_sd, metadata, v2 == "True", base_model
+
+
+def merge(args):
+    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+
+    def str_to_dtype(p):
+        if p == "float":
+            return torch.float
+        if p == "fp16":
+            return torch.float16
+        if p == "bf16":
+            return torch.bfloat16
+        return None
+
+    merge_dtype = str_to_dtype(args.precision)
+    save_dtype = str_to_dtype(args.save_precision)
+    if save_dtype is None:
+        save_dtype = merge_dtype
+
+    new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
+    state_dict, metadata, v2, base_model = merge_lora_models(
+        args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
+    )
+
+    print(f"calculating hashes and creating metadata...")
+
+    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+    metadata["sshs_model_hash"] = model_hash
+    metadata["sshs_legacy_hash"] = legacy_hash
+
+    if not args.no_metadata:
+        is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
+        merged_from = sai_model_spec.build_merged_from(args.models)
+        title = os.path.splitext(os.path.basename(args.save_to))[0]
+        sai_metadata = sai_model_spec.build_metadata(
+            state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
+        )
+        if v2:
+            # TODO read sai modelspec
+            print(
+                "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
+            )
+        metadata.update(sai_metadata)
+
+    print(f"saving model to: {args.save_to}")
+    save_to_file(args.save_to, state_dict, save_dtype, metadata)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
+    )
+    parser.add_argument(
+        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+    )
+    parser.add_argument(
+        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
+    )
+    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
+    parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
+    parser.add_argument(
+        "--new_conv_rank",
+        type=int,
+        default=None,
+        help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
+    )
+    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--no_metadata",
+        action="store_true",
+        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    merge(args)
diff --git a/external/llite/tools/cache_latents.py b/external/llite/tools/cache_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..17916ef707dbcb992add8826c3e3b4fd808cf18b
--- /dev/null
+++ b/external/llite/tools/cache_latents.py
@@ -0,0 +1,194 @@
+# latentsのdiskへの事前キャッシュを行う / cache latents to disk
+
+import argparse
+import math
+from multiprocessing import Value
+import os
+
+from accelerate.utils import set_seed
+import torch
+from tqdm import tqdm
+
+from library import config_util
+from library import train_util
+from library import sdxl_train_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+
+
+def cache_to_disk(args: argparse.Namespace) -> None:
+    train_util.prepare_dataset_args(args, True)
+
+    # check cache latents arg
+    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
+
+    use_dreambooth_method = args.in_json is None
+
+    if args.seed is not None:
+        set_seed(args.seed)  # initialize the random seed
+
+    # prepare tokenizer(s): required to run the dataset
+    if args.sdxl:
+        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+        tokenizers = [tokenizer1, tokenizer2]
+    else:
+        tokenizer = train_util.load_tokenizer(args)
+        tokenizers = [tokenizer]
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                print("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                print("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
+
+    # if the dataset's cache_latents is not called, raw images are returned
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+    # prepare the accelerator
+    print("prepare accelerator")
+    accelerator = train_util.prepare_accelerator(args)
+
+    # prepare a dtype matching the mixed precision setting and cast as needed
+    weight_dtype, _ = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
+
+    # load the model
+    print("load model")
+    if args.sdxl:
+        (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
+    else:
+        _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+
+    if torch.__version__ >= "2.0.0":  # usable with xformers that supports PyTorch 2.0.0 or later
+        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+    vae.to(accelerator.device, dtype=vae_dtype)
+    vae.requires_grad_(False)
+    vae.eval()
+
+    # prepare the dataloader
+    train_dataset_group.set_caching_mode("latents")
+
+    # number of DataLoader worker processes: 0 means the main process handles loading
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset_group,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=collator,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
+
+    # prepare with accelerator: this should make multi-GPU usable
+    train_dataloader = accelerator.prepare(train_dataloader)
+
+    # loop to fetch the data
+    for batch in tqdm(train_dataloader):
+        b_size = len(batch["images"])
+        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
+        flip_aug = batch["flip_aug"]
+        random_crop = batch["random_crop"]
+        bucket_reso = batch["bucket_reso"]
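+        # the DataLoader batch size is 1, so each batch is a single dataset-side bucket batch; flip_aug, random_crop and bucket_reso apply to the whole batch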
+
+        # split the batch into vae_batch_size chunks and process them
+        for i in range(0, b_size, vae_batch_size):
+            images = batch["images"][i : i + vae_batch_size]
+            absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
+            resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
+
+            image_infos = []
+            for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):
+                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
+                image_info.image = image
+                image_info.bucket_reso = bucket_reso
+                image_info.resized_size = resized_size
+                image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
+
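+                # skip images whose latents npz already exists with the expected contents (the check also accounts for flip_aug)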
+                if args.skip_existing:
+                    if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
+                        print(f"Skipping {image_info.latents_npz} because it already exists.")
+                        continue
+
+                image_infos.append(image_info)
+
+            if len(image_infos) > 0:
+                train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
+
+    accelerator.wait_for_everyone()
+    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+
+    train_util.add_sd_models_arguments(parser)
+    train_util.add_training_arguments(parser, True)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    config_util.add_config_arguments(parser)
+    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )
+    parser.add_argument(
+        "--skip_existing",
+        action="store_true",
+        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    cache_to_disk(args)
diff --git a/external/llite/tools/cache_text_encoder_outputs.py b/external/llite/tools/cache_text_encoder_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9b13d68c35e3f080060f0ad78f854c0c471f6c
--- /dev/null
+++ b/external/llite/tools/cache_text_encoder_outputs.py
@@ -0,0 +1,191 @@
+# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
+
+import argparse
+import math
+from multiprocessing import Value
+import os
+
+from accelerate.utils import set_seed
+import torch
+from tqdm import tqdm
+
+from library import config_util
+from library import train_util
+from library import sdxl_train_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+
+
+def cache_to_disk(args: argparse.Namespace) -> None:
+    train_util.prepare_dataset_args(args, True)
+
+    # check cache arg
+    assert (
+        args.cache_text_encoder_outputs_to_disk
+    ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
+
+    # prepared to be as generic as possible, but currently only SDXL works
+    assert (
+        args.sdxl
+    ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
+
+    use_dreambooth_method = args.in_json is None
+
+    if args.seed is not None:
+        set_seed(args.seed)  # initialize the random seed
+
+    # prepare tokenizer(s): required to run the dataset
+    if args.sdxl:
+        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+        tokenizers = [tokenizer1, tokenizer2]
+    else:
+        tokenizer = train_util.load_tokenizer(args)
+        tokenizers = [tokenizer]
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                print("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                print("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+    # prepare the accelerator
+    print("prepare accelerator")
+    accelerator = train_util.prepare_accelerator(args)
+
+    # prepare a dtype matching the mixed precision setting and cast as needed
+    weight_dtype, _ = train_util.prepare_dtype(args)
+
+    # load the model
+    print("load model")
+    if args.sdxl:
+        (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
+        text_encoders = [text_encoder1, text_encoder2]
+    else:
+        text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+        text_encoders = [text_encoder1]
+
+    for text_encoder in text_encoders:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
+        text_encoder.requires_grad_(False)
+        text_encoder.eval()
+
+    # prepare the dataloader
+    train_dataset_group.set_caching_mode("text")
+
+    # number of DataLoader worker processes: 0 means the main process handles loading
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset_group,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=collator,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
+
+    # prepare with accelerator: this should make multi-GPU usable
+    train_dataloader = accelerator.prepare(train_dataloader)
+
+    # loop to fetch the data
+    for batch in tqdm(train_dataloader):
+        absolute_paths = batch["absolute_paths"]
+        input_ids1_list = batch["input_ids1_list"]
+        input_ids2_list = batch["input_ids2_list"]
+
+        image_infos = []
+        for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
+            image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
+            image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+
+            if args.skip_existing:
+                if os.path.exists(image_info.text_encoder_outputs_npz):
+                    print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
+                    continue
+                
+            image_info.input_ids1 = input_ids1
+            image_info.input_ids2 = input_ids2
+            image_infos.append(image_info)
+
+        if len(image_infos) > 0:
+            b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
+            b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
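+            # encode the batch with both text encoders and cache the outputs to each image's npz file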
+            train_util.cache_batch_text_encoder_outputs(
+                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
+            )
+
+    accelerator.wait_for_everyone()
+    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+
+    train_util.add_sd_models_arguments(parser)
+    train_util.add_training_arguments(parser, True)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    config_util.add_config_arguments(parser)
+    sdxl_train_util.add_sdxl_training_arguments(parser)
+    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
+    parser.add_argument(
+        "--skip_existing",
+        action="store_true",
+        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    cache_to_disk(args)
diff --git a/external/llite/tools/canny.py b/external/llite/tools/canny.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0806898786e5251d2e715e33896bb4958a35e8
--- /dev/null
+++ b/external/llite/tools/canny.py
@@ -0,0 +1,30 @@
+import argparse
+import cv2
+
+
+def canny(args):
+  img = cv2.imread(args.input)
+  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+  canny_img = cv2.Canny(img, args.thres1, args.thres2)
+  # canny_img = 255 - canny_img
+
+  cv2.imwrite(args.output, canny_img)
+  print("done!")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--input", type=str, default=None, help="input path")
+  parser.add_argument("--output", type=str, default=None, help="output path")
+  parser.add_argument("--thres1", type=int, default=32, help="thres1")
+  parser.add_argument("--thres2", type=int, default=224, help="thres2")
+
+  return parser
+
+
+if __name__ == '__main__':
+  parser = setup_parser()
+
+  args = parser.parse_args()
+  canny(args)
diff --git a/external/llite/tools/convert_diffusers20_original_sd.py b/external/llite/tools/convert_diffusers20_original_sd.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe30996aa509bf93581d528028fca18aa855aa77
--- /dev/null
+++ b/external/llite/tools/convert_diffusers20_original_sd.py
@@ -0,0 +1,160 @@
+# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
+
+import argparse
+import os
+import torch
+from diffusers import StableDiffusionPipeline
+
+import library.model_util as model_util
+
+
+def convert(args):
+    # check the arguments
+    load_dtype = torch.float16 if args.fp16 else None
+
+    save_dtype = None
+    if args.fp16 or args.save_precision_as == "fp16":
+        save_dtype = torch.float16
+    elif args.bf16 or args.save_precision_as == "bf16":
+        save_dtype = torch.bfloat16
+    elif args.float or args.save_precision_as == "float":
+        save_dtype = torch.float
+
+    is_load_ckpt = os.path.isfile(args.model_to_load)
+    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
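+    # a file extension on model_to_save means a single checkpoint file; no extension means a Diffusers directory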
+
+    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
+    # assert (
+    #     is_save_ckpt or args.reference_model is not None
+    # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
+
+    # load the model
+    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
+    print(f"loading {msg}: {args.model_to_load}")
+
+    if is_load_ckpt:
+        v2_model = args.v2
+        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
+            v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
+        )
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(
+            args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
+        )
+        text_encoder = pipe.text_encoder
+        vae = pipe.vae
+        unet = pipe.unet
+
+        if args.v1 == args.v2:
+            # auto-detect the model version
+            v2_model = unet.config.cross_attention_dim == 1024
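+            # v2 uses the OpenCLIP text encoder with 1024-dim context, v1 uses CLIP ViT-L/14 with 768, so cross_attention_dim identifies the version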
+            print("checking model version: model is " + ("v2" if v2_model else "v1"))
+        else:
+            v2_model = not args.v1
+
+    # convert and save
+    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
+    print(f"converting and saving as {msg}: {args.model_to_save}")
+
+    if is_save_ckpt:
+        original_model = args.model_to_load if is_load_ckpt else None
+        key_count = model_util.save_stable_diffusion_checkpoint(
+            v2_model,
+            args.model_to_save,
+            text_encoder,
+            unet,
+            original_model,
+            args.epoch,
+            args.global_step,
+            None if args.metadata is None else eval(args.metadata),
+            save_dtype=save_dtype,
+            vae=vae,
+        )
+        print(f"model saved. total converted state_dict keys: {key_count}")
+    else:
+        print(
+            f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
+        )
+        model_util.save_diffusers_checkpoint(
+            v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
+        )
+        print("model saved.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
+    )
+    parser.add_argument(
+        "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
+    )
+    parser.add_argument(
+        "--unet_use_linear_projection",
+        action="store_true",
+        help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
+    )
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
+    )
+    parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
+    parser.add_argument(
+        "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
+    )
+    parser.add_argument(
+        "--save_precision_as",
+        type=str,
+        default="no",
+        choices=["fp16", "bf16", "float"],
+        help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
+    )
+    parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
+    parser.add_argument(
+        "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        default=None,
+        help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written into the model, given as a Python dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
+    )
+    parser.add_argument(
+        "--reference_model",
+        type=str,
+        default=None,
+        help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
+    )
+    parser.add_argument(
+        "--use_safetensors",
+        action="store_true",
+        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
+    )
+
+    parser.add_argument(
+        "model_to_load",
+        type=str,
+        default=None,
+        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
+    )
+    parser.add_argument(
+        "model_to_save",
+        type=str,
+        default=None,
+        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    convert(args)
diff --git a/external/llite/tools/detect_face_rotate.py b/external/llite/tools/detect_face_rotate.py
new file mode 100644
index 0000000000000000000000000000000000000000..68dec6cae932e827e79c49992238be7fd2edf21c
--- /dev/null
+++ b/external/llite/tools/detect_face_rotate.py
@@ -0,0 +1,246 @@
+# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
+# (c) 2022 Kohya S. @kohya_ss
+
+# Detect faces in wide images, rotate them so the face is upright, and crop a square centered on the face
+
+# v2: extract max face if multiple faces are found
+# v3: add crop_ratio option
+# v4: add multiple faces extraction and min/max size
+
+import argparse
+import math
+import cv2
+import glob
+import os
+from anime_face_detector import create_detector
+from tqdm import tqdm
+import numpy as np
+
+KP_REYE = 11
+KP_LEYE = 19
+
+SCORE_THRES = 0.90
+
+
+def detect_faces(detector, image, min_size):
+  preds = detector(image)                     # bgr
+  # print(len(preds))
+
+  faces = []
+  for pred in preds:
+    bb = pred['bbox']
+    score = bb[-1]
+    if score < SCORE_THRES:
+      continue
+
+    left, top, right, bottom = bb[:4]
+    cx = int((left + right) / 2)
+    cy = int((top + bottom) / 2)
+    fw = int(right - left)
+    fh = int(bottom - top)
+
+    lex, ley = pred['keypoints'][KP_LEYE, 0:2]
+    rex, rey = pred['keypoints'][KP_REYE, 0:2]
+    angle = math.atan2(ley - rey, lex - rex)
+    angle = angle / math.pi * 180
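+    # angle of the line between the eyes in degrees, used later to rotate the face upright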
+
+    faces.append((cx, cy, fw, fh, angle))
+
+  faces.sort(key=lambda x: max(x[2], x[3]), reverse=True)         # largest first
+  return faces
+
+
+def rotate_image(image, angle, cx, cy):
+  h, w = image.shape[0:2]
+  rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
+
+  # # enlarge the image a little to allow for the rotation -> disabled for now
+  # nh = max(h, int(w * math.sin(angle)))
+  # nw = max(w, int(h * math.sin(angle)))
+  # if nh > h or nw > w:
+  #   pad_y = nh - h
+  #   pad_t = pad_y // 2
+  #   pad_x = nw - w
+  #   pad_l = pad_x // 2
+  #   m = np.array([[0, 0, pad_l],
+  #                 [0, 0, pad_t]])
+  #   rot_mat = rot_mat + m
+  #   h, w = nh, nw
+  #   cx += pad_l
+  #   cy += pad_t
+
+  result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
+  return result, cx, cy
+
+
+def process(args):
+  assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
+  assert args.crop_ratio is None or args.resize_face_size is None, f"resize_face_size can't be specified with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"
+
+  # load the anime face detection model
+  print("loading face detector.")
+  detector = create_detector('yolov3')
+
+  # parse the crop arguments
+  if args.crop_size is None:
+    crop_width = crop_height = None
+  else:
+    tokens = args.crop_size.split(',')
+    assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
+    crop_width, crop_height = [int(t) for t in tokens]
+
+  if args.crop_ratio is None:
+    crop_h_ratio = crop_v_ratio = None
+  else:
+    tokens = args.crop_ratio.split(',')
+    assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
+    crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
+
+  # process the images
+  print("processing.")
+  output_extension = ".png"
+
+  os.makedirs(args.dst_dir, exist_ok=True)
+  paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
+      glob.glob(os.path.join(args.src_dir, "*.webp"))
+  for path in tqdm(paths):
+    basename = os.path.splitext(os.path.basename(path))[0]
+
+    # image = cv2.imread(path)        # fails on Japanese (non-ASCII) file names, so decode via np.fromfile instead
+    image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
+    if len(image.shape) == 2:
+      image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+    if image.shape[2] == 4:
+      print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
+      image = image[:, :, :3].copy()                    # without copy(), the alpha information apparently stays attached internally
+
+    h, w = image.shape[:2]
+
+    faces = detect_faces(detector, image, args.multiple_faces)
+    for i, face in enumerate(faces):
+      cx, cy, fw, fh, angle = face
+      face_size = max(fw, fh)
+      if args.min_size is not None and face_size < args.min_size:
+        continue
+      if args.max_size is not None and face_size >= args.max_size:
+        continue
+      face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
+
+      # rotate if the option is specified
+      face_img = image
+      if args.rotate:
+        face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
+
+      # crop around the face if the option is specified
+      if crop_width is not None or crop_h_ratio is not None:
+        cur_crop_width, cur_crop_height = crop_width, crop_height
+        if crop_h_ratio is not None:
+          cur_crop_width = int(face_size * crop_h_ratio + .5)
+          cur_crop_height = int(face_size * crop_v_ratio + .5)
+
+        # resize if necessary
+        scale = 1.0
+        if args.resize_face_size is not None:
+          # resize based on the face size
+          scale = args.resize_face_size / face_size
+          if scale < cur_crop_width / w:
+            print(
+                f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+            scale = cur_crop_width / w
+          if scale < cur_crop_height / h:
+            print(
+                f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+            scale = cur_crop_height / h
+        elif crop_h_ratio is not None:
+          # do not resize when a ratio is given
+          pass
+        else:
+          # crop size is specified
+          if w < cur_crop_width:
+            print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
+            scale = cur_crop_width / w
+          if h < cur_crop_height:
+            print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
+            scale = cur_crop_height / h
+          if args.resize_fit:
+            scale = max(cur_crop_width / w, cur_crop_height / h)
+
+        if scale != 1.0:
+          w = int(w * scale + .5)
+          h = int(h * scale + .5)
+          face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
+          cx = int(cx * scale + .5)
+          cy = int(cy * scale + .5)
+          fw = int(fw * scale + .5)
+          fh = int(fh * scale + .5)
+
+        cur_crop_width = min(cur_crop_width, face_img.shape[1])
+        cur_crop_height = min(cur_crop_height, face_img.shape[0])
+
+        x = cx - cur_crop_width // 2
+        cx = cur_crop_width // 2
+        if x < 0:
+          cx = cx + x
+          x = 0
+        elif x + cur_crop_width > w:
+          cx = cx + (x + cur_crop_width - w)
+          x = w - cur_crop_width
+        face_img = face_img[:, x:x+cur_crop_width]
+
+        y = cy - cur_crop_height // 2
+        cy = cur_crop_height // 2
+        if y < 0:
+          cy = cy + y
+          y = 0
+        elif y + cur_crop_height > h:
+          cy = cy + (y + cur_crop_height - h)
+          y = h - cur_crop_height
+        face_img = face_img[y:y + cur_crop_height]
+
+      # # debug
+      # print(path, cx, cy, angle)
+      # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
+      # cv2.imshow("image", crp)
+      # if cv2.waitKey() == 27:
+      #   break
+      # cv2.destroyAllWindows()
+
+      # debug
+      if args.debug:
+        cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
+
+      _, buf = cv2.imencode(output_extension, face_img)
+      with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
+        buf.tofile(f)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
+  parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
+  parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
+  parser.add_argument("--resize_fit", action="store_true",
+                      help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
+  parser.add_argument("--resize_face_size", type=int, default=None,
+                      help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
+  parser.add_argument("--crop_size", type=str, default=None,
+                      help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
+  parser.add_argument("--crop_ratio", type=str, default=None,
+                      help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
+  parser.add_argument("--min_size", type=int, default=None,
+                      help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
+  parser.add_argument("--max_size", type=int, default=None,
+                      help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
+  parser.add_argument("--multiple_faces", action="store_true",
+                      help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
+  parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
+
+  return parser
+
+
+if __name__ == '__main__':
+  parser = setup_parser()
+
+  args = parser.parse_args()
+
+  process(args)
diff --git a/external/llite/tools/latent_upscaler.py b/external/llite/tools/latent_upscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1fa3390fa781c24fe796f1ff0987c91dcc3f95
--- /dev/null
+++ b/external/llite/tools/latent_upscaler.py
@@ -0,0 +1,348 @@
+# script to call the upscaler easily from outside
+# the model definition is included so it can run standalone
+
+import argparse
+import glob
+import os
+import cv2
+from diffusers import AutoencoderKL
+
+from typing import Dict, List
+import numpy as np
+
+import torch
+from torch import nn
+from tqdm import tqdm
+from PIL import Image
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
+        super(ResidualBlock, self).__init__()
+
+        if out_channels is None:
+            out_channels = in_channels
+
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.relu2 = nn.ReLU(inplace=True)  # this ReLU might be better applied before adding to the residual
+
+        # initialize weights
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += residual
+
+        out = self.relu2(out)
+
+        return out
+
+
+class Upscaler(nn.Module):
+    def __init__(self):
+        super(Upscaler, self).__init__()
+
+        # define layers
+        # latent has 4 channels
+
+        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+        self.bn1 = nn.BatchNorm2d(128)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        # resblocks
+        # brute force: 20 blocks; stacking more blocks should widen the receptive field more than adding channels would
+        self.resblock1 = ResidualBlock(128)
+        self.resblock2 = ResidualBlock(128)
+        self.resblock3 = ResidualBlock(128)
+        self.resblock4 = ResidualBlock(128)
+        self.resblock5 = ResidualBlock(128)
+        self.resblock6 = ResidualBlock(128)
+        self.resblock7 = ResidualBlock(128)
+        self.resblock8 = ResidualBlock(128)
+        self.resblock9 = ResidualBlock(128)
+        self.resblock10 = ResidualBlock(128)
+        self.resblock11 = ResidualBlock(128)
+        self.resblock12 = ResidualBlock(128)
+        self.resblock13 = ResidualBlock(128)
+        self.resblock14 = ResidualBlock(128)
+        self.resblock15 = ResidualBlock(128)
+        self.resblock16 = ResidualBlock(128)
+        self.resblock17 = ResidualBlock(128)
+        self.resblock18 = ResidualBlock(128)
+        self.resblock19 = ResidualBlock(128)
+        self.resblock20 = ResidualBlock(128)
+
+        # last convs
+        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        # final conv: output 4 channels
+        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
+
+        # initialize weights
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+        # initialize final conv weights to 0: the popular "zero conv" trick
+        nn.init.constant_(self.conv_final.weight, 0)
+
+    def forward(self, x):
+        inp = x
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+
+        # adding the residual back after every few resblocks should improve accuracy and training speed
+        residual = x
+        x = self.resblock1(x)
+        x = self.resblock2(x)
+        x = self.resblock3(x)
+        x = self.resblock4(x)
+        x = x + residual
+        residual = x
+        x = self.resblock5(x)
+        x = self.resblock6(x)
+        x = self.resblock7(x)
+        x = self.resblock8(x)
+        x = x + residual
+        residual = x
+        x = self.resblock9(x)
+        x = self.resblock10(x)
+        x = self.resblock11(x)
+        x = self.resblock12(x)
+        x = x + residual
+        residual = x
+        x = self.resblock13(x)
+        x = self.resblock14(x)
+        x = self.resblock15(x)
+        x = self.resblock16(x)
+        x = x + residual
+        residual = x
+        x = self.resblock17(x)
+        x = self.resblock18(x)
+        x = self.resblock19(x)
+        x = self.resblock20(x)
+        x = x + residual
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu2(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+
+        # it seems better not to apply a relu here
+
+        x = self.conv_final(x)
+
+        # network estimates the difference between the input and the output
+        x = x + inp
+
+        return x
+
+    def support_latents(self) -> bool:
+        return False
+
+    def upscale(
+        self,
+        vae: AutoencoderKL,
+        lowreso_images: List[Image.Image],
+        lowreso_latents: torch.Tensor,
+        dtype: torch.dtype,
+        width: int,
+        height: int,
+        batch_size: int = 1,
+        vae_batch_size: int = 1,
+    ):
+        # assertion
+        assert lowreso_images is not None, "Upscaler requires lowreso image"
+
+        # make upsampled image with lanczos4
+        upsampled_images = []
+        for lowreso_image in lowreso_images:
+            upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
+            upsampled_images.append(upsampled_image)
+
+        # convert to tensor; keep it on the CPU, since the full stack may be too large to move to CUDA at once
+        upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
+        upsampled_images = torch.stack(upsampled_images, dim=0)
+        upsampled_images = upsampled_images.to(dtype)
+
+        # normalize to [-1, 1]
+        upsampled_images = upsampled_images / 127.5 - 1.0
+
+        # encode the upsampled images to latents in batches
+        # print("Encoding upsampled (LANCZOS4) images...")
+        upsampled_latents = []
+        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
+            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
+            with torch.no_grad():
+                batch = vae.encode(batch).latent_dist.sample()
+            upsampled_latents.append(batch)
+
+        upsampled_latents = torch.cat(upsampled_latents, dim=0)
+
+        # upscale (refine) the latents with this model in batches
+        print("Upscaling latents...")
+        upscaled_latents = []
+        for i in range(0, upsampled_latents.shape[0], batch_size):
+            with torch.no_grad():
+                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
+        upscaled_latents = torch.cat(upscaled_latents, dim=0)
+
+        return upscaled_latents * 0.18215
+
+
+# external interface: returns a model
+def create_upscaler(**kwargs):
+    weights = kwargs["weights"]
+    model = Upscaler()
+
+    print(f"Loading weights from {weights}...")
+    if os.path.splitext(weights)[1] == ".safetensors":
+        from safetensors.torch import load_file
+
+        sd = load_file(weights)
+    else:
+        sd = torch.load(weights, map_location=torch.device("cpu"))
+    model.load_state_dict(sd)
+    return model
+
+
+# another interface: upscale the given images with the model, driven from the command line
+def upscale_images(args: argparse.Namespace):
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    us_dtype = torch.float16  # TODO: support fp32/bf16
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # load VAE with Diffusers
+    assert args.vae_path is not None, "VAE path is required"
+    print(f"Loading VAE from {args.vae_path}...")
+    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
+    vae.to(DEVICE, dtype=us_dtype)
+
+    # prepare model
+    print("Preparing model...")
+    upscaler: Upscaler = create_upscaler(weights=args.weights)
+    # print("Loading weights from", args.weights)
+    # upscaler.load_state_dict(torch.load(args.weights))
+    upscaler.eval()
+    upscaler.to(DEVICE, dtype=us_dtype)
+
+    # load images
+    image_paths = glob.glob(args.image_pattern)
+    images = []
+    for image_path in image_paths:
+        image = Image.open(image_path)
+        image = image.convert("RGB")
+
+        # make divisible by 8
+        width = image.width
+        height = image.height
+        if width % 8 != 0:
+            width = width - (width % 8)
+        if height % 8 != 0:
+            height = height - (height % 8)
+        if width != image.width or height != image.height:
+            image = image.crop((0, 0, width, height))
+
+        images.append(image)
+
+    # debug output
+    if args.debug:
+        for image, image_path in zip(images, image_paths):
+            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
+
+            basename = os.path.basename(image_path)
+            basename_wo_ext, ext = os.path.splitext(basename)
+            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
+            image_debug.save(dest_file_name)
+
+    # upscale
+    print("Upscaling...")
+    upscaled_latents = upscaler.upscale(
+        vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
+    )
+    upscaled_latents /= 0.18215
+
+    # decode with batch
+    print("Decoding...")
+    upscaled_images = []
+    for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
+        with torch.no_grad():
+            batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
+        batch = batch.to("cpu")
+        upscaled_images.append(batch)
+    upscaled_images = torch.cat(upscaled_images, dim=0)
+
+    # tensor to numpy
+    upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
+    upscaled_images = (upscaled_images + 1.0) * 127.5
+    upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
+
+    # RGB -> BGR, since cv2.imwrite expects BGR channel order
+    upscaled_images = upscaled_images[..., ::-1]
+
+    # save images
+    for i, image in enumerate(upscaled_images):
+        basename = os.path.basename(image_paths[i])
+        basename_wo_ext, ext = os.path.splitext(basename)
+        dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
+        cv2.imwrite(dest_file_name, image)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
+    parser.add_argument("--weights", type=str, default=None, help="Weights path")
+    parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
+    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
+    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
+    parser.add_argument("--debug", action="store_true", help="Debug mode")
+
+    args = parser.parse_args()
+    upscale_images(args)
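For reference, the upscaler above can also be driven directly from Python instead of the upscale_images() CLI. A minimal sketch, assuming the code above is in scope (or imported as a module) and using placeholder paths for a Diffusers SD model directory and the upscaler weights:

import torch
from PIL import Image
from diffusers import AutoencoderKL

# placeholder paths; the VAE must belong to the SD model these latents will be used with
vae = AutoencoderKL.from_pretrained("/path/to/sd-model", subfolder="vae")
vae.to("cuda", dtype=torch.float16)

model = create_upscaler(weights="/path/to/upscaler.safetensors")  # defined above
model.eval().to("cuda", dtype=torch.float16)

image = Image.open("input.png").convert("RGB")
# crop to multiples of 8 first, as upscale_images() does, so the VAE encode stays valid
w, h = image.width - image.width % 8, image.height - image.height % 8
image = image.crop((0, 0, w, h))

latents = model.upscale(vae, [image], None, torch.float16, w * 2, h * 2)
# the returned latents are already multiplied by 0.18215; divide by it before vae.decode()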
diff --git a/external/llite/tools/merge_models.py b/external/llite/tools/merge_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..391bfe677bfd60d396851157cb9f040c7be9b4bb
--- /dev/null
+++ b/external/llite/tools/merge_models.py
@@ -0,0 +1,168 @@
+import argparse
+import os
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+
+def is_unet_key(key):
+    # VAE or TextEncoder, the last one is for SDXL
+    return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key)
+
+
+TEXT_ENCODER_KEY_REPLACEMENTS = [
+    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
+    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
+    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
+]
+
+
+# support for models with different text encoder keys
+def replace_text_encoder_key(key):
+    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+        if key.startswith(rep_from):
+            return True, rep_to + key[len(rep_from) :]
+    return False, key
+
+
+def merge(args):
+    if args.precision == "fp16":
+        dtype = torch.float16
+    elif args.precision == "bf16":
+        dtype = torch.bfloat16
+    else:
+        dtype = torch.float
+
+    if args.saving_precision == "fp16":
+        save_dtype = torch.float16
+    elif args.saving_precision == "bf16":
+        save_dtype = torch.bfloat16
+    else:
+        save_dtype = torch.float
+
+    # check if all models are safetensors
+    for model in args.models:
+        if not model.endswith("safetensors"):
+            print(f"Model {model} is not a safetensors model")
+            exit()
+        if not os.path.isfile(model):
+            print(f"Model {model} does not exist")
+            exit()
+
+    assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
+
+    # load and merge
+    ratio = 1.0 / len(args.models)  # default
+    supplementary_key_ratios = {}  # [key] = ratio, for keys not in all models, add later
+
+    merged_sd = None
+    first_model_keys = set()  # check missing keys in other models
+    for i, model in enumerate(args.models):
+        if args.ratios is not None:
+            ratio = args.ratios[i]
+
+        if merged_sd is None:
+            # load first model
+            print(f"Loading model {model}, ratio = {ratio}...")
+            merged_sd = {}
+            with safe_open(model, framework="pt", device=args.device) as f:
+                for key in tqdm(f.keys()):
+                    value = f.get_tensor(key)
+                    _, key = replace_text_encoder_key(key)
+
+                    first_model_keys.add(key)
+
+                    if not is_unet_key(key) and args.unet_only:
+                        supplementary_key_ratios[key] = 1.0  # use first model's value for VAE or TextEncoder
+                        continue
+
+                    value = ratio * value.to(dtype)  # first model's value * ratio
+                    merged_sd[key] = value
+
+            print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
+            continue
+
+        # load other models
+        print(f"Loading model {model}, ratio = {ratio}...")
+
+        with safe_open(model, framework="pt", device=args.device) as f:
+            model_keys = f.keys()
+            for key in tqdm(model_keys):
+                _, new_key = replace_text_encoder_key(key)
+                if new_key not in merged_sd:
+                    if args.show_skipped and new_key not in first_model_keys:
+                        print(f"Skip: {new_key}")
+                    continue
+
+                value = f.get_tensor(key)
+                merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)
+
+            # enumerate keys not in this model
+            model_keys = set(model_keys)
+            for key in merged_sd.keys():
+                if key in model_keys:
+                    continue
+                print(f"Key {key} not in model {model}, use first model's value")
+                if key in supplementary_key_ratios:
+                    supplementary_key_ratios[key] += ratio
+                else:
+                    supplementary_key_ratios[key] = ratio
+
+    # add supplementary keys' value (including VAE and TextEncoder)
+    if len(supplementary_key_ratios) > 0:
+        print("add first model's value")
+        with safe_open(args.models[0], framework="pt", device=args.device) as f:
+            for key in tqdm(f.keys()):
+                _, new_key = replace_text_encoder_key(key)
+                if new_key not in supplementary_key_ratios:
+                    continue
+
+                if is_unet_key(new_key):  # not VAE or TextEncoder
+                    print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
+
+                value = f.get_tensor(key)  # original key
+
+                if new_key not in merged_sd:
+                    merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype)
+                else:
+                    merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype)
+
+    # save
+    output_file = args.output
+    if not output_file.endswith(".safetensors"):
+        output_file = output_file + ".safetensors"
+
+    print(f"Saving to {output_file}...")
+
+    # convert to save_dtype
+    for k in merged_sd.keys():
+        merged_sd[k] = merged_sd[k].to(save_dtype)
+
+    save_file(merged_sd, output_file)
+
+    print("Done!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Merge models")
+    parser.add_argument("--models", nargs="+", type=str, help="Models to merge")
+    parser.add_argument("--output", type=str, help="Output model")
+    parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
+    parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
+    parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
+    parser.add_argument(
+        "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
+    )
+    parser.add_argument(
+        "--saving_precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="Saving precision, default is float",
+    )
+    parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")
+
+    args = parser.parse_args()
+    merge(args)
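At its core, merge() is a key-wise weighted sum of the state dicts, with the first model supplying the value for any key the other models lack. A stripped-down sketch of that arithmetic for two already-loaded safetensors files, ignoring the key renaming, --unet_only, and precision options handled above (file names are placeholders):

import torch
from safetensors.torch import load_file, save_file

def merge_two(path_a, path_b, ratio_a, ratio_b, out_path):
    sd_a = load_file(path_a)
    sd_b = load_file(path_b)
    merged = {}
    for key, tensor_a in sd_a.items():
        if key in sd_b:
            # weighted sum of the two models for shared keys
            merged[key] = ratio_a * tensor_a.float() + ratio_b * sd_b[key].float()
        else:
            # key missing from the second model: fall back to the first model's value
            merged[key] = tensor_a.float()
    save_file({k: v.to(torch.float16) for k, v in merged.items()}, out_path)

# merge_two("model_a.safetensors", "model_b.safetensors", 0.7, 0.3, "merged.safetensors")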
diff --git a/external/llite/tools/original_control_net.py b/external/llite/tools/original_control_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd47bd76adc294244782b86f1c43cf037b5c0918
--- /dev/null
+++ b/external/llite/tools/original_control_net.py
@@ -0,0 +1,337 @@
+from typing import List, NamedTuple, Any
+import numpy as np
+import cv2
+import torch
+from safetensors.torch import load_file
+
+from library.original_unet import UNet2DConditionModel, SampleOutput
+
+import library.model_util as model_util
+
+
+class ControlNetInfo(NamedTuple):
+    unet: Any
+    net: Any
+    prep: Any
+    weight: float
+    ratio: float
+
+
+class ControlNet(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        # make control model
+        self.control_model = torch.nn.Module()
+
+        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
+        zero_convs = torch.nn.ModuleList()
+        for i, dim in enumerate(dims):
+            sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
+            zero_convs.append(sub_list)
+        self.control_model.add_module("zero_convs", zero_convs)
+
+        middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
+        self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
+
+        dims = [16, 16, 32, 32, 96, 96, 256, 320]
+        strides = [1, 1, 2, 1, 2, 1, 2, 1]
+        prev_dim = 3
+        input_hint_block = torch.nn.Sequential()
+        for i, (dim, stride) in enumerate(zip(dims, strides)):
+            input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
+            if i < len(dims) - 1:
+                input_hint_block.append(torch.nn.SiLU())
+            prev_dim = dim
+        self.control_model.add_module("input_hint_block", input_hint_block)
+
+
+def load_control_net(v2, unet, model):
+    device = unet.device
+
+    # extract only the parts corresponding to the U-Net from the control SD checkpoint, converting keys, and load them into the Diffusers U-Net
+    # load the state dict
+    print(f"ControlNet: loading control SD model : {model}")
+
+    if model_util.is_safetensors(model):
+        ctrl_sd_sd = load_file(model)
+    else:
+        ctrl_sd_sd = torch.load(model, map_location="cpu")
+        ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
+
+    # make the weights loadable into the U-Net; the ControlNet ships as an SD-format state dict, so load it as such
+    is_difference = "difference" in ctrl_sd_sd
+    print("ControlNet: loading difference:", is_difference)
+
+    # some keys are missing from the ControlNet, so first build the full set of SD-format keys from the current U-Net
+    # these also serve as the base weights for Transfer Control
+    ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
+
+    # copy so that the original U-Net is not affected; the prefix is missing, so add it
+    for key in list(ctrl_unet_sd_sd.keys()):
+        ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
+
+    zero_conv_sd = {}
+    for key in list(ctrl_sd_sd.keys()):
+        if key.startswith("control_"):
+            unet_key = "model.diffusion_" + key[len("control_") :]
+            if unet_key not in ctrl_unet_sd_sd:  # zero conv
+                zero_conv_sd[key] = ctrl_sd_sd[key]
+                continue
+            if is_difference:  # Transfer Control
+                ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+            else:
+                ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+
+    unet_config = model_util.create_unet_diffusers_config(v2)
+    ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # state dict of the Diffusers-format ControlNet U-Net
+
+    # create the U-Net for the ControlNet
+    ctrl_unet = UNet2DConditionModel(**unet_config)
+    info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
+    print("ControlNet: loading Control U-Net:", info)
+
+    # create the non-U-Net parts of the ControlNet (zero convs, hint block)
+    # TODO support middle only
+    ctrl_net = ControlNet()
+    info = ctrl_net.load_state_dict(zero_conv_sd)
+    print("ControlNet: loading ControlNet:", info)
+
+    ctrl_unet.to(unet.device, dtype=unet.dtype)
+    ctrl_net.to(unet.device, dtype=unet.dtype)
+    return ctrl_unet, ctrl_net
+
+
+def load_preprocess(prep_type: str):
+    if prep_type is None or prep_type.lower() == "none":
+        return None
+
+    if prep_type.startswith("canny"):
+        args = prep_type.split("_")
+        th1 = int(args[1]) if len(args) >= 2 else 63
+        th2 = int(args[2]) if len(args) >= 3 else 191
+
+        def canny(img):
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            return cv2.Canny(img, th1, th2)
+
+        return canny
+
+    print("Unsupported prep type:", prep_type)
+    return None
+
+
+def preprocess_ctrl_net_hint_image(image):
+    image = np.array(image).astype(np.float32) / 255.0
+    # the ControlNet reference code uses cv2, but here images are loaded via Gradio, so they are actually RGB already
+    # image = image[:, :, ::-1].copy()                         # rgb to bgr
+    image = image[None].transpose(0, 3, 1, 2)  # nchw
+    image = torch.from_numpy(image)
+    return image  # 0 to 1
+
+
+def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
+    guided_hints = []
+    for i, cnet_info in enumerate(control_nets):
+        # hints must be ordered as: cnet1 of image 1, cnet2 of image 1, cnet3 of image 1, cnet1 of image 2, cnet2 of image 2, ...
+        b_hints = []
+        if len(hints) == 1:  # use the same image as the hint for every sample in the batch
+            hint = hints[0]
+            if cnet_info.prep is not None:
+                hint = cnet_info.prep(hint)
+            hint = preprocess_ctrl_net_hint_image(hint)
+            b_hints = [hint for _ in range(b_size)]
+        else:
+            for bi in range(b_size):
+                hint = hints[(bi * len(control_nets) + i) % len(hints)]
+                if cnet_info.prep is not None:
+                    hint = cnet_info.prep(hint)
+                hint = preprocess_ctrl_net_hint_image(hint)
+                b_hints.append(hint)
+        b_hints = torch.cat(b_hints, dim=0)
+        b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
+
+        guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
+        guided_hints.append(guided_hint)
+    return guided_hints
+
+
+def call_unet_and_control_net(
+    step,
+    num_latent_input,
+    original_unet,
+    control_nets: List[ControlNetInfo],
+    guided_hints,
+    current_ratio,
+    sample,
+    timestep,
+    encoder_hidden_states,
+    encoder_hidden_states_for_control_net,
+):
+    # ControlNet
+    # with multiple ControlNets, apply them alternately (one per step) rather than merging their outputs
+    cnet_cnt = len(control_nets)
+    cnet_idx = step % cnet_cnt
+    cnet_info = control_nets[cnet_idx]
+
+    # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+    if cnet_info.ratio < current_ratio:
+        return original_unet(sample, timestep, encoder_hidden_states)
+
+    guided_hint = guided_hints[cnet_idx]
+    guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
+    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
+    outs = [o * cnet_info.weight for o in outs]
+
+    # U-Net
+    return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
+
+
+"""
+  # this is the version that merges the ControlNet outputs
+  # ControlNet
+  cnet_outs_list = []
+  for i, cnet_info in enumerate(control_nets):
+    # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+    if cnet_info.ratio < current_ratio:
+      continue
+    guided_hint = guided_hints[i]
+    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
+    for i in range(len(outs)):
+      outs[i] *= cnet_info.weight
+
+    cnet_outs_list.append(outs)
+
+  count = len(cnet_outs_list)
+  if count == 0:
+    return original_unet(sample, timestep, encoder_hidden_states)
+
+  # sum of controlnets
+  for i in range(1, count):
+    cnet_outs_list[0] += cnet_outs_list[i]
+
+  # U-Net
+  return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
+"""
+
+
+def unet_forward(
+    is_control_net,
+    control_net: ControlNet,
+    unet: UNet2DConditionModel,
+    guided_hint,
+    ctrl_outs,
+    sample,
+    timestep,
+    encoder_hidden_states,
+):
+    # copy from UNet2DConditionModel
+    default_overall_up_factor = 2**unet.num_upsamplers
+
+    forward_upsample_size = False
+    upsample_size = None
+
+    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+        print("Forward upsample size to force interpolation output size.")
+        forward_upsample_size = True
+
+    # 1. time
+    timesteps = timestep
+    if not torch.is_tensor(timesteps):
+        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+        # This would be a good case for the `match` statement (Python 3.10+)
+        is_mps = sample.device.type == "mps"
+        if isinstance(timestep, float):
+            dtype = torch.float32 if is_mps else torch.float64
+        else:
+            dtype = torch.int32 if is_mps else torch.int64
+        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+    elif len(timesteps.shape) == 0:
+        timesteps = timesteps[None].to(sample.device)
+
+    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+    timesteps = timesteps.expand(sample.shape[0])
+
+    t_emb = unet.time_proj(timesteps)
+
+    # timesteps does not contain any weights and will always return f32 tensors
+    # but time_embedding might actually be running in fp16. so we need to cast here.
+    # there might be better ways to encapsulate this.
+    t_emb = t_emb.to(dtype=unet.dtype)
+    emb = unet.time_embedding(t_emb)
+
+    outs = []  # output of ControlNet
+    zc_idx = 0
+
+    # 2. pre-process
+    sample = unet.conv_in(sample)
+    if is_control_net:
+        sample += guided_hint
+        outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
+        zc_idx += 1
+
+    # 3. down
+    down_block_res_samples = (sample,)
+    for downsample_block in unet.down_blocks:
+        if downsample_block.has_cross_attention:
+            sample, res_samples = downsample_block(
+                hidden_states=sample,
+                temb=emb,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+        else:
+            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+        if is_control_net:
+            for rs in res_samples:
+                outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
+                zc_idx += 1
+
+        down_block_res_samples += res_samples
+
+    # 4. mid
+    sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+    if is_control_net:
+        outs.append(control_net.control_model.middle_block_out[0](sample))
+        return outs
+
+    if not is_control_net:
+        sample += ctrl_outs.pop()
+
+    # 5. up
+    for i, upsample_block in enumerate(unet.up_blocks):
+        is_final_block = i == len(unet.up_blocks) - 1
+
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+        if not is_control_net and len(ctrl_outs) > 0:
+            res_samples = list(res_samples)
+            apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
+            ctrl_outs = ctrl_outs[: -len(res_samples)]
+            for j in range(len(res_samples)):
+                res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
+            res_samples = tuple(res_samples)
+
+        # if we have not reached the final block and need to forward the
+        # upsample size, we do it here
+        if not is_final_block and forward_upsample_size:
+            upsample_size = down_block_res_samples[-1].shape[2:]
+
+        if upsample_block.has_cross_attention:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                encoder_hidden_states=encoder_hidden_states,
+                upsample_size=upsample_size,
+            )
+        else:
+            sample = upsample_block(
+                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+            )
+    # 6. post-process
+    sample = unet.conv_norm_out(sample)
+    sample = unet.conv_act(sample)
+    sample = unet.conv_out(sample)
+
+    return SampleOutput(sample=sample)
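Two details of call_unet_and_control_net() above are easy to miss: when several ControlNets are given, they are applied alternately (one net per sampling step), and a net is skipped entirely once the denoising progress passes its ratio. A toy sketch of just that scheduling logic, detached from any U-Net call (the names here are illustrative, not part of this module):

from typing import List, NamedTuple, Optional

class FakeCNet(NamedTuple):
    name: str
    weight: float
    ratio: float  # fraction of the schedule during which this net stays active

def pick_active_net(step: int, current_ratio: float, nets: List[FakeCNet]) -> Optional[FakeCNet]:
    net = nets[step % len(nets)]   # alternate between nets, one per step
    if net.ratio < current_ratio:  # past this net's active window: run the plain U-Net only
        return None
    return net

nets = [FakeCNet("canny", 1.0, 0.5), FakeCNet("depth", 0.8, 1.0)]
for step in range(6):
    current_ratio = step / 6
    print(step, pick_active_net(step, current_ratio, nets))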
diff --git a/external/llite/tools/resize_images_to_resolution.py b/external/llite/tools/resize_images_to_resolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d3224c4e28aaad71113e3bab8140da78a69bc2b
--- /dev/null
+++ b/external/llite/tools/resize_images_to_resolution.py
@@ -0,0 +1,128 @@
+import glob
+import os
+import cv2
+import argparse
+import shutil
+import math
+from PIL import Image
+import numpy as np
+
+
+def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
+  # Split the max_resolution string by "," and strip any whitespaces
+  max_resolutions = [res.strip() for res in max_resolution.split(',')]
+
+  # # Calculate max_pixels from max_resolution string
+  # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+  # Create destination folder if it does not exist
+  if not os.path.exists(dst_img_folder):
+    os.makedirs(dst_img_folder)
+
+  # Select interpolation method
+  if interpolation == 'lanczos4':
+    cv2_interpolation = cv2.INTER_LANCZOS4
+  elif interpolation == 'cubic':
+    cv2_interpolation = cv2.INTER_CUBIC
+  else:
+    cv2_interpolation = cv2.INTER_AREA
+
+  # Iterate through all files in src_img_folder
+  img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")                   # copy from train_util.py
+  for filename in os.listdir(src_img_folder):
+    # Check if the image is png, jpg or webp etc...
+    if not filename.endswith(img_exts):
+      # Copy non-image files (.txt, .caption, etc.) to the destination folder as-is
+      shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
+      continue
+
+    # Load image
+    # img = cv2.imread(os.path.join(src_img_folder, filename))
+    image = Image.open(os.path.join(src_img_folder, filename))
+    if not image.mode == "RGB":
+      image = image.convert("RGB")
+    img = np.array(image, np.uint8)
+
+    base, _ = os.path.splitext(filename)
+    for max_resolution in max_resolutions:
+      # Calculate max_pixels from max_resolution string
+      max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+      # Calculate current number of pixels
+      current_pixels = img.shape[0] * img.shape[1]
+
+      # Check if the image needs resizing
+      if current_pixels > max_pixels:
+        # Calculate scaling factor
+        scale_factor = max_pixels / current_pixels
+
+        # Calculate new dimensions
+        new_height = int(img.shape[0] * math.sqrt(scale_factor))
+        new_width = int(img.shape[1] * math.sqrt(scale_factor))
+
+        # Resize image
+        img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
+      else:
+        new_height, new_width = img.shape[0:2]
+
+      # Calculate the new height and width that are divisible by divisible_by (with/without resizing)
+      new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
+      new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
+
+      # Center crop the image to the calculated dimensions
+      y = int((img.shape[0] - new_height) / 2)
+      x = int((img.shape[1] - new_width) / 2)
+      img = img[y:y + new_height, x:x + new_width]
+
+      # Split filename into base and extension
+      new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
+
+      # Save resized image in dst_img_folder
+      # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
+      image = Image.fromarray(img)
+      image.save(os.path.join(dst_img_folder, new_filename), quality=100)
+
+      proc = "Resized" if current_pixels > max_pixels else "Saved"
+      print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
+
+    # If other files with same basename, copy them with resolution suffix
+    if copy_associated_files:
+      asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
+      for asoc_file in asoc_files:
+        ext = os.path.splitext(asoc_file)[1]
+        if ext in img_exts:
+          continue
+        for max_resolution in max_resolutions:
+          new_asoc_file = base + '+' + max_resolution + ext
+          print(f"Copy {asoc_file} as {new_asoc_file}")
+          shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
+
+
+def setup_parser() -> argparse.ArgumentParser:
+  parser = argparse.ArgumentParser(
+      description='Resize images in a folder so their area fits within the specified maximum resolution(s), preserving aspect ratio')
+  parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
+  parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
+  parser.add_argument('--max_resolution', type=str,
+                      help='Maximum resolution(s), comma separated, in the format "512x512,384x384, etc, etc"', default="512x512,384x384,256x256,128x128")
+  parser.add_argument('--divisible_by', type=int,
+                      help='Ensure new dimensions are divisible by this value', default=1)
+  parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
+                      default='area', help='Interpolation method for resizing')
+  parser.add_argument('--save_as_png', action='store_true', help='Save as png format')
+  parser.add_argument('--copy_associated_files', action='store_true',
+                      help='Copy files with the same base name as the images (captions etc.)')
+
+  return parser
+
+
+def main():
+  parser = setup_parser()
+
+  args = parser.parse_args()
+  resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
+                args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
+
+
+if __name__ == '__main__':
+  main()
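The resize above is area-based: scale_factor is max_pixels / current_pixels, and both sides are multiplied by its square root, so the aspect ratio is preserved while the pixel count lands at or below the target. A quick worked example of that arithmetic (values truncated the way int() truncates):

import math

src_w, src_h = 1920, 1200
max_pixels = 512 * 512                   # 262144
scale = max_pixels / (src_w * src_h)     # ~0.1138
new_w = int(src_w * math.sqrt(scale))    # 647
new_h = int(src_h * math.sqrt(scale))    # 404
# with --divisible_by 2 the odd width is then trimmed to 646 and the image center-cropped
print(new_w, new_h, new_w * new_h <= max_pixels)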
diff --git a/external/llite/tools/show_metadata.py b/external/llite/tools/show_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ca7b1c8aec9a5a8047925475b23b79d6bb21f5
--- /dev/null
+++ b/external/llite/tools/show_metadata.py
@@ -0,0 +1,19 @@
+import json
+import argparse
+from safetensors import safe_open
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", type=str, required=True)
+args = parser.parse_args()
+
+with safe_open(args.model, framework="pt") as f:
+    metadata = f.metadata()
+
+if metadata is None:
+    print("No metadata found")
+else:
+    # metadata is a flat string-to-string dict, not pretty printed by default;
+    # sort by key and pretty-print it
+    print(json.dumps(metadata, indent=4, sort_keys=True))
+
+    
\ No newline at end of file
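show_metadata.py only reads the header; for completeness, metadata can be attached when a safetensors file is written, since it is stored as a string-to-string dict in the file header. A minimal sketch (the tensor and metadata keys are arbitrary example values, not ones this repository requires):

import torch
from safetensors.torch import save_file

tensors = {"example.weight": torch.zeros(2, 2)}
metadata = {"ss_base_model_version": "sdxl_1.0", "note": "arbitrary example"}  # must be str -> str
save_file(tensors, "example.safetensors", metadata=metadata)
# python show_metadata.py --model example.safetensors  # pretty-prints the dict above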
diff --git a/inference.py b/inference.py
index f6e5f5e6383adab77b083a09eb9e8193083da0aa..fadcce1d711fe0b4df257ac71916677af866be9c 100644
--- a/inference.py
+++ b/inference.py
@@ -468,24 +468,46 @@ def img2img(task: Task):
 
     width, height = get_intermediate_dimension(task)
 
-    lora_patcher = lora_style.get_patcher(
-        [img2img_pipe.pipe, high_res.pipe], task.get_style()
-    )
-    lora_patcher.patch()
-
     torch.manual_seed(task.get_seed())
 
-    kwargs = {
-        "prompt": prompt,
-        "imageUrl": task.get_imageUrl(),
-        "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
-        "num_inference_steps": task.get_steps(),
-        "width": width,
-        "height": height,
-        **task.i2i_kwargs(),
-        **lora_patcher.kwargs(),
-    }
-    images, has_nsfw = img2img_pipe.process(**kwargs)
+    if get_is_sdxl():
+        # we run lineart for img2img
+        controlnet.load_model("linearart")
+
+        lora_patcher = lora_style.get_patcher(
+            [controlnet.pipe2, high_res.pipe], task.get_style()
+        )
+        lora_patcher.patch()
+
+        kwargs = {
+            "imageUrl": task.get_imageUrl(),
+            "seed": task.get_seed(),
+            "num_inference_steps": task.get_steps(),
+            "width": width,
+            "height": height,
+            "prompt": prompt,
+            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
+            **task.cnl_kwargs(),
+            "adapter_conditioning_scale": 0.3,
+        }
+        images, has_nsfw = controlnet.process(**kwargs)
+    else:
+        lora_patcher = lora_style.get_patcher(
+            [img2img_pipe.pipe, high_res.pipe], task.get_style()
+        )
+        lora_patcher.patch()
+
+        kwargs = {
+            "prompt": prompt,
+            "imageUrl": task.get_imageUrl(),
+            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
+            "num_inference_steps": task.get_steps(),
+            "width": width,
+            "height": height,
+            **task.i2i_kwargs(),
+            **lora_patcher.kwargs(),
+        }
+        images, has_nsfw = img2img_pipe.process(**kwargs)
 
     if task.get_high_res_fix():
         kwargs = {
diff --git a/internals/pipelines/controlnets.py b/internals/pipelines/controlnets.py
index d0803c7abf3a537b2b676bb03d6ec5ecab57effc..2bcb89334a5a00f8ecd702be08687529a968e3d1 100644
--- a/internals/pipelines/controlnets.py
+++ b/internals/pipelines/controlnets.py
@@ -12,6 +12,7 @@ from controlnet_aux import (
 from diffusers import (
     ControlNetModel,
     DiffusionPipeline,
+    EulerAncestralDiscreteScheduler,
     StableDiffusionAdapterPipeline,
     StableDiffusionControlNetImg2ImgPipeline,
     StableDiffusionControlNetPipeline,
@@ -200,6 +201,10 @@ class ControlNet(AbstractPipeline):
                 pipe.enable_vae_tiling()
                 pipe.enable_vae_slicing()
                 pipe.enable_xformers_memory_efficient_attention()
+                # this scheduler produces good outputs for t2i adapters
+                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+                    pipe.scheduler.config
+                )
             else:
                 pipe.enable_xformers_memory_efficient_attention()
             return pipe
@@ -394,7 +399,7 @@ class ControlNet(AbstractPipeline):
         sdxl_args = (
             {
                 "guidance_scale": 6,
-                "adapter_conditioning_scale": 0.6,
+                "adapter_conditioning_scale": 1.0,
                 "adapter_conditioning_factor": 1.0,
             }
             if get_is_sdxl()
@@ -440,8 +445,8 @@ class ControlNet(AbstractPipeline):
         sdxl_args = (
             {
                 "guidance_scale": 6,
-                "adapter_conditioning_scale": 0.5,
-                "adapter_conditioning_factor": 0.9,
+                "adapter_conditioning_scale": 1.0,
+                "adapter_conditioning_factor": 1.0,
             }
             if get_is_sdxl()
             else {}
@@ -479,9 +484,12 @@ class ControlNet(AbstractPipeline):
         return image
 
     @staticmethod
-    def linearart_condition_image(image: Image.Image) -> Image.Image:
+    def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
         processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
-        image = processor.__call__(input_image=image)
+        if get_is_sdxl():
+            kwargs = {"detect_resolution": 384, **kwargs}
+
+        image = processor.__call__(input_image=image, **kwargs)
         return image
 
     @staticmethod
diff --git a/internals/pipelines/high_res.py b/internals/pipelines/high_res.py
index 3eccfdc2deaed6704bfc58ab975bdde4f7331b80..ad06b8f34c69038ae9fae9832e199d44113391c0 100644
--- a/internals/pipelines/high_res.py
+++ b/internals/pipelines/high_res.py
@@ -6,7 +6,7 @@ from PIL import Image
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline, Img2Img
 from internals.util.cache import clear_cuda_and_gc
-from internals.util.config import get_base_dimension, get_model_dir
+from internals.util.config import get_base_dimension, get_is_sdxl, get_model_dir
 
 
 class HighRes(AbstractPipeline):
@@ -60,4 +60,59 @@ class HighRes(AbstractPipeline):
         firstpass_width = math.ceil(scale * target_width / 64) * 64
         firstpass_height = math.ceil(scale * target_height / 64) * 64
 
+        print("Pass1", firstpass_width, firstpass_height)
+
+        if get_is_sdxl():
+            firstpass_width, firstpass_height = HighRes.find_closest_sdxl_aspect_ratio(
+                firstpass_width, firstpass_height
+            )
+
+        print("Pass2", firstpass_width, firstpass_height)
         return firstpass_width, firstpass_height
+
+    @staticmethod
+    def find_closest_sdxl_aspect_ratio(target_width: int, target_height: int):
+        target_ratio = target_width / target_height
+        closest_ratio = ""
+        min_difference = float("inf")
+
+        for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
+            ratio = width / height
+            difference = abs(target_ratio - ratio)
+
+            if difference < min_difference:
+                min_difference = difference
+                closest_ratio = ratio_str
+
+        new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
+        return new_width, new_height
+
+
+SD_XL_BASE_RATIOS = {
+    "0.5": (704, 1408),
+    "0.52": (704, 1344),
+    "0.57": (768, 1344),
+    "0.6": (768, 1280),
+    "0.68": (832, 1216),
+    "0.72": (832, 1152),
+    "0.78": (896, 1152),
+    "0.82": (896, 1088),
+    "0.88": (960, 1088),
+    "0.94": (960, 1024),
+    "1.0": (1024, 1024),
+    "1.07": (1024, 960),
+    "1.13": (1088, 960),
+    "1.21": (1088, 896),
+    "1.29": (1152, 896),
+    "1.38": (1152, 832),
+    "1.46": (1216, 832),
+    "1.67": (1280, 768),
+    "1.75": (1344, 768),
+    "1.91": (1344, 704),
+    "2.0": (1408, 704),
+    "2.09": (1472, 704),
+    "2.4": (1536, 640),
+    "2.5": (1600, 640),
+    "2.89": (1664, 576),
+    "3.0": (1728, 576),
+}
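find_closest_sdxl_aspect_ratio() simply snaps a first-pass size onto the SD_XL_BASE_RATIOS bucket whose width/height ratio is nearest, so the first pass always renders at a resolution SDXL was trained on. A small usage sketch (the import path follows this repository's layout):

from internals.pipelines.high_res import HighRes

print(HighRes.find_closest_sdxl_aspect_ratio(768, 512))    # -> (1216, 832), ratio ~1.46
print(HighRes.find_closest_sdxl_aspect_ratio(1024, 1024))  # -> (1024, 1024), ratio 1.0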
diff --git a/internals/pipelines/realtime_draw.py b/internals/pipelines/realtime_draw.py
index 58f91f787d4d8a2c091cfc6bd7da327a8422677f..a5c36a9a8e9cdaa37fa5e075cba5f3e931e5d25b 100644
--- a/internals/pipelines/realtime_draw.py
+++ b/internals/pipelines/realtime_draw.py
@@ -7,7 +7,9 @@ from PIL import Image
 import internals.util.image as ImageUtil
 from internals.pipelines.commons import AbstractPipeline
 from internals.pipelines.controlnets import ControlNet
-from internals.util.config import get_hf_cache_dir
+from internals.pipelines.high_res import HighRes
+from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
+from internals.util.config import get_base_dimension, get_hf_cache_dir, get_is_sdxl
 
 
 class RealtimeDraw(AbstractPipeline):
@@ -15,28 +17,38 @@ class RealtimeDraw(AbstractPipeline):
         if hasattr(self, "pipe"):
             return
 
-        self.__controlnet_scribble = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_scribble",
-            torch_dtype=torch.float16,
-            cache_dir=get_hf_cache_dir(),
-        )
-
-        self.__controlnet_seg = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_seg",
-            torch_dtype=torch.float16,
-            cache_dir=get_hf_cache_dir(),
-        )
-
-        kwargs = {**pipeline.pipe.components}  # pyright: ignore
-        kwargs.pop("image_encoder", None)
-        self.pipe = StableDiffusionControlNetImg2ImgPipeline(
-            **kwargs, controlnet=self.__controlnet_seg
-        ).to("cuda")
-        self.pipe.safety_checker = None
-        self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
-            **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
-        ).to("cuda")
-        self.pipe2.safety_checker = None
+        if get_is_sdxl():
+            lite_pipe = SDXLLLiteImg2ImgPipeline()
+            lite_pipe.load(
+                pipeline,
+                [
+                    "https://s3.ap-south-1.amazonaws.com/autodraft.model.assets/models/replicate-xl-llite.safetensors"
+                ],
+            )
+            self.pipe = lite_pipe
+        else:
+            self.__controlnet_scribble = ControlNetModel.from_pretrained(
+                "lllyasviel/control_v11p_sd15_scribble",
+                torch_dtype=torch.float16,
+                cache_dir=get_hf_cache_dir(),
+            )
+
+            self.__controlnet_seg = ControlNetModel.from_pretrained(
+                "lllyasviel/control_v11p_sd15_seg",
+                torch_dtype=torch.float16,
+                cache_dir=get_hf_cache_dir(),
+            )
+
+            kwargs = {**pipeline.pipe.components}  # pyright: ignore
+            kwargs.pop("image_encoder", None)
+            self.pipe = StableDiffusionControlNetImg2ImgPipeline(
+                **kwargs, controlnet=self.__controlnet_seg
+            ).to("cuda")
+            self.pipe.safety_checker = None
+            self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
+                **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
+            ).to("cuda")
+            self.pipe2.safety_checker = None
 
     def process_seg(
         self,
@@ -45,6 +57,9 @@ class RealtimeDraw(AbstractPipeline):
         negative_prompt: str,
         seed: int,
     ):
+        if get_is_sdxl():
+            raise Exception("SDXL is not supported for this method")
+
         torch.manual_seed(seed)
 
         image = ImageUtil.resize_image(image, 512)
@@ -71,35 +86,55 @@ class RealtimeDraw(AbstractPipeline):
     ):
         torch.manual_seed(seed)
 
+        b_dimen = get_base_dimension()
+
         if not image:
-            size = (512, 512)
+            size = (b_dimen, b_dimen)
             if image2:
                 size = image2.size
             image = Image.new("RGB", size, color=0)
 
         if not image2:
-            size = (512, 512)
+            size = (b_dimen, b_dimen)
             if image:
                 size = image.size
             image2 = Image.new("RGB", size, color=0)
 
-        image = ImageUtil.resize_image(image, 512)
-
-        scribble = ControlNet.scribble_image(image)
-
-        image2 = ImageUtil.resize_image(image2, 512)
-
-        img = self.pipe2.__call__(
-            image=image,
-            control_image=[scribble, image2],
-            prompt=prompt,
-            num_inference_steps=15,
-            negative_prompt=negative_prompt,
-            guidance_scale=10,
-            strength=0.9,
-            width=image.size[0],
-            height=image.size[1],
-            controlnet_conditioning_scale=[1.0, 0.8],
-        ).images[0]
+        if get_is_sdxl():
+            size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
+            image = image.resize(size)
+
+            images = self.pipe.__call__(
+                image=image,
+                condition_image=image,
+                negative_prompt=negative_prompt,
+                prompt=prompt,
+                seed=seed,
+                num_inference_steps=10,
+                width=image.size[0],
+                height=image.size[1],
+            )
+            img = images[0]
+        else:
+            image = ImageUtil.resize_image(image, b_dimen)
+
+            scribble = ControlNet.scribble_image(image)
+
+            image2 = ImageUtil.resize_image(image2, b_dimen)
+
+            img = self.pipe2.__call__(
+                image=image,
+                control_image=[scribble, image2],
+                prompt=prompt,
+                num_inference_steps=15,
+                negative_prompt=negative_prompt,
+                guidance_scale=10,
+                strength=0.9,
+                width=image.size[0],
+                height=image.size[1],
+                controlnet_conditioning_scale=[1.0, 0.8],
+            ).images[0]
+
+        img = ImageUtil.resize_image(img, 512)
 
         return img
diff --git a/internals/pipelines/sdxl_llite_pipeline.py b/internals/pipelines/sdxl_llite_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c0baa52a460fad76100ca2d01d85ad2371727de
--- /dev/null
+++ b/internals/pipelines/sdxl_llite_pipeline.py
@@ -0,0 +1,1689 @@
+import inspect
+import re
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import diffusers
+import numpy as np
+import PIL
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    LCMScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    StableDiffusionXLPipeline,
+)
+from diffusers.configuration_utils import FrozenDict
+from diffusers.utils.deprecation_utils import deprecate
+from einops import rearrange
+from PIL import Image
+from PIL.PngImagePlugin import PngInfo
+from safetensors.torch import load_file
+from tqdm import tqdm
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+import external.llite.library.model_util as model_util
+import external.llite.library.sdxl_model_util as sdxl_model_util
+import external.llite.library.sdxl_original_unet as sdxl_original_unet
+import external.llite.library.sdxl_train_util as sdxl_train_util
+import external.llite.library.train_util as train_util
+from external.llite.library.original_unet import FlashAttentionFunction
+from external.llite.library.sdxl_original_unet import InferSdxlUNet2DConditionModel
+from external.llite.networks.control_net_lllite import ControlNetLLLite
+from external.llite.networks.lora import LoRANetwork
+from internals.pipelines.commons import AbstractPipeline
+from internals.util.cache import clear_cuda_and_gc
+from internals.util.commons import download_file
+
+
+class PipelineLike:
+    def __init__(
+        self,
+        device,
+        vae: AutoencoderKL,
+        text_encoders: List[CLIPTextModel],
+        tokenizers: List[CLIPTokenizer],
+        unet: InferSdxlUNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        clip_skip: int,
+    ):
+        super().__init__()
+        self.device = device
+        self.clip_skip = clip_skip
+
+        if (
+            hasattr(scheduler.config, "steps_offset")
+            and scheduler.config.steps_offset != 1
+        ):
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate(
+                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
+            )
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if (
+            hasattr(scheduler.config, "clip_sample")
+            and scheduler.config.clip_sample is True
+        ):
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate(
+                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
+            )
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        self.vae = vae
+        self.text_encoders = text_encoders
+        self.tokenizers = tokenizers
+        self.unet: InferSdxlUNet2DConditionModel = unet
+        self.scheduler = scheduler
+        self.safety_checker = None
+
+        self.clip_vision_model: CLIPVisionModelWithProjection = None
+        self.clip_vision_processor: CLIPImageProcessor = None
+        self.clip_vision_strength = 0.0
+
+        # Textual Inversion
+        self.token_replacements_list = []
+        for _ in range(len(self.text_encoders)):
+            self.token_replacements_list.append({})
+
+        # ControlNet # not supported yet
+        self.control_nets: List[ControlNetLLLite] = []
+        self.control_net_enabled = True  # if control_nets is empty, ControlNet does nothing regardless of True/False
+
+    # Textual Inversion
+    def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
+        self.token_replacements_list[text_encoder_index][
+            target_token_id
+        ] = rep_token_ids
+
+    def set_enable_control_net(self, en: bool):
+        self.control_net_enabled = en
+
+    def preprocess_image(self, image):
+        w, h = image.size
+        # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 32, (w, h))
+        image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        return 2.0 * image - 1.0
+
+    def get_unweighted_text_embeddings(
+        self,
+        text_encoder: CLIPTextModel,
+        text_input: torch.Tensor,
+        chunk_length: int,
+        clip_skip: int,
+        eos: int,
+        pad: int,
+        no_boseos_middle: Optional[bool] = True,
+    ):
+        """
+        When the token sequence exceeds the capacity of the text encoder,
+        it is split into chunks and each chunk is sent to the text encoder individually.
+        """
+        max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+        if max_embeddings_multiples > 1:
+            text_embeddings = []
+            pool = None
+            for i in range(max_embeddings_multiples):
+                # extract the i-th chunk
+                text_input_chunk = text_input[
+                    :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
+                ].clone()
+
+                # cover the head and the tail by the starting and the ending tokens
+                text_input_chunk[:, 0] = text_input[0, 0]
+                if pad == eos:  # v1
+                    text_input_chunk[:, -1] = text_input[0, -1]
+                else:  # v2
+                    for j in range(len(text_input_chunk)):
+                        # a regular token remains at the end
+                        if (
+                            text_input_chunk[j, -1] != eos
+                            and text_input_chunk[j, -1] != pad
+                        ):
+                            text_input_chunk[j, -1] = eos
+                        if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
+                            text_input_chunk[j, 1] = eos
+
+                # -2 is same for Text Encoder 1 and 2
+                enc_out = text_encoder(
+                    text_input_chunk, output_hidden_states=True, return_dict=True
+                )
+                text_embedding = enc_out["hidden_states"][-2]
+                if pool is None:
+                    # use 1st chunk, if provided
+                    pool = enc_out.get("text_embeds", None)
+                    if pool is not None:
+                        pool = train_util.pool_workaround(
+                            text_encoder,
+                            enc_out["last_hidden_state"],
+                            text_input_chunk,
+                            eos,
+                        )
+
+                if no_boseos_middle:
+                    if i == 0:
+                        # discard the ending token
+                        text_embedding = text_embedding[:, :-1]
+                    elif i == max_embeddings_multiples - 1:
+                        # discard the starting token
+                        text_embedding = text_embedding[:, 1:]
+                    else:
+                        # discard both starting and ending tokens
+                        text_embedding = text_embedding[:, 1:-1]
+
+                text_embeddings.append(text_embedding)
+            text_embeddings = torch.concat(text_embeddings, axis=1)
+        else:
+            enc_out = text_encoder(
+                text_input, output_hidden_states=True, return_dict=True
+            )
+            text_embeddings = enc_out["hidden_states"][-2]
+            # text encoder 1 doesn't return this
+            pool = enc_out.get("text_embeds", None)
+            if pool is not None:
+                pool = train_util.pool_workaround(
+                    text_encoder, enc_out["last_hidden_state"], text_input, eos
+                )
+        return text_embeddings, pool
+
+    def preprocess_mask(self, mask):
+        mask = mask.convert("L")
+        w, h = mask.size
+        # resize to integer multiple of 32
+        w, h = map(lambda x: x - x % 32, (w, h))
+        mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # LANCZOS)
+        mask = np.array(mask).astype(np.float32) / 255.0
+        mask = np.tile(mask, (4, 1, 1))
+        mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
+        mask = 1 - mask  # repaint white, keep black
+        mask = torch.from_numpy(mask)
+        return mask
+
+    def get_prompts_with_weights(
+        self,
+        tokenizer: CLIPTokenizer,
+        token_replacer,
+        prompt: List[str],
+        max_length: int,
+    ):
+        r"""
+        Tokenize a list of prompts and return its tokens with weights of each token.
+        No padding, starting or ending token is included.
+        """
+        tokens = []
+        weights = []
+        truncated = False
+
+        def parse_prompt_attention(text):
+            """
+            Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+            Accepted tokens are:
+            (abc) - increases attention to abc by a multiplier of 1.1
+            (abc:3.12) - increases attention to abc by a multiplier of 3.12
+            [abc] - decreases attention to abc by a multiplier of 1.1
+            \( - literal character '('
+            \[ - literal character '['
+            \) - literal character ')'
+            \] - literal character ']'
+            \\ - literal character '\'
+            anything else - just text
+            >>> parse_prompt_attention('normal text')
+            [['normal text', 1.0]]
+            >>> parse_prompt_attention('an (important) word')
+            [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+            >>> parse_prompt_attention('(unbalanced')
+            [['unbalanced', 1.1]]
+            >>> parse_prompt_attention('\(literal\]')
+            [['(literal]', 1.0]]
+            >>> parse_prompt_attention('(unnecessary)(parens)')
+            [['unnecessaryparens', 1.1]]
+            >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+            [['a ', 1.0],
+            ['house', 1.5730000000000004],
+            [' ', 1.1],
+            ['on', 1.0],
+            [' a ', 1.1],
+            ['hill', 0.55],
+            [', sun, ', 1.1],
+            ['sky', 1.4641000000000006],
+            ['.', 1.1]]
+            """
+
+            res = []
+            round_brackets = []
+            square_brackets = []
+
+            round_bracket_multiplier = 1.1
+            square_bracket_multiplier = 1 / 1.1
+
+            def multiply_range(start_position, multiplier):
+                for p in range(start_position, len(res)):
+                    res[p][1] *= multiplier
+
+            # keep break as separate token
+            text = text.replace("BREAK", "\\BREAK\\")
+            re_attention = re.compile(
+                r"""
+            \\\(|
+            \\\)|
+            \\\[|
+            \\]|
+            \\\\|
+            \\|
+            \(|
+            \[|
+            :([+-]?[.\d]+)\)|
+            \)|
+            ]|
+            [^\\()\[\]:]+|
+            :
+            """,
+                re.X,
+            )
+            for m in re_attention.finditer(text):
+                text = m.group(0)
+                weight = m.group(1)
+
+                if text.startswith("\\"):
+                    res.append([text[1:], 1.0])
+                elif text == "(":
+                    round_brackets.append(len(res))
+                elif text == "[":
+                    square_brackets.append(len(res))
+                elif weight is not None and len(round_brackets) > 0:
+                    multiply_range(round_brackets.pop(), float(weight))
+                elif text == ")" and len(round_brackets) > 0:
+                    multiply_range(round_brackets.pop(), round_bracket_multiplier)
+                elif text == "]" and len(square_brackets) > 0:
+                    multiply_range(square_brackets.pop(), square_bracket_multiplier)
+                else:
+                    res.append([text, 1.0])
+
+            for pos in round_brackets:
+                multiply_range(pos, round_bracket_multiplier)
+
+            for pos in square_brackets:
+                multiply_range(pos, square_bracket_multiplier)
+
+            if len(res) == 0:
+                res = [["", 1.0]]
+
+            # merge runs of identical weights
+            i = 0
+            while i + 1 < len(res):
+                if (
+                    res[i][1] == res[i + 1][1]
+                    and res[i][0].strip() != "BREAK"
+                    and res[i + 1][0].strip() != "BREAK"
+                ):
+                    res[i][0] += res[i + 1][0]
+                    res.pop(i + 1)
+                else:
+                    i += 1
+
+            return res
+
+        for text in prompt:
+            texts_and_weights = parse_prompt_attention(text)
+            text_token = []
+            text_weight = []
+            for word, weight in texts_and_weights:
+                if word.strip() == "BREAK":
+                    # pad until next multiple of tokenizer's max token length
+                    pad_len = tokenizer.model_max_length - (
+                        len(text_token) % tokenizer.model_max_length
+                    )
+                    print(f"BREAK pad_len: {pad_len}")
+                    for i in range(pad_len):
+                        # not sure whether EOS should be appended here for v2
+                        # if i == 0:
+                        #     text_token.append(tokenizer.eos_token_id)
+                        # else:
+                        text_token.append(tokenizer.pad_token_id)
+                        text_weight.append(1.0)
+                    continue
+
+                # tokenize and discard the starting and the ending token
+                token = tokenizer(word).input_ids[1:-1]
+
+                token = token_replacer(token)  # for Textual Inversion
+
+                text_token += token
+                # copy the weight by length of token
+                text_weight += [weight] * len(token)
+                # stop if the text is too long (longer than truncation limit)
+                if len(text_token) > max_length:
+                    truncated = True
+                    break
+            # truncate
+            if len(text_token) > max_length:
+                truncated = True
+                text_token = text_token[:max_length]
+                text_weight = text_weight[:max_length]
+            tokens.append(text_token)
+            weights.append(text_weight)
+        if truncated:
+            print(
+                "warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
+            )
+        return tokens, weights
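+
+    # Illustrative sketch (not executed): for a prompt such as "a (red:1.2) car",
+    # parse_prompt_attention yields [["a ", 1.0], ["red", 1.2], [" car", 1.0]]; each
+    # fragment is then tokenized with BOS/EOS stripped and its weight repeated once per
+    # resulting token id, so tokens[i] and weights[i] always have the same length.
+    # The concrete token ids depend on the CLIP tokenizer.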
+
+    def pad_tokens_and_weights(
+        self,
+        tokens,
+        weights,
+        max_length,
+        bos,
+        eos,
+        pad,
+        no_boseos_middle=True,
+        chunk_length=77,
+    ):
+        r"""
+        Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+        """
+        max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+        weights_length = (
+            max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+        )
+        for i in range(len(tokens)):
+            tokens[i] = (
+                [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
+            )
+            if no_boseos_middle:
+                weights[i] = (
+                    [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+                )
+            else:
+                w = []
+                if len(weights[i]) == 0:
+                    w = [1.0] * weights_length
+                else:
+                    for j in range(max_embeddings_multiples):
+                        # weight for starting token in this chunk
+                        w.append(1.0)
+                        w += weights[i][
+                            j
+                            * (chunk_length - 2) : min(
+                                len(weights[i]), (j + 1) * (chunk_length - 2)
+                            )
+                        ]
+                        w.append(1.0)  # weight for ending token in this chunk
+                    w += [1.0] * (weights_length - len(w))
+                weights[i] = w[:]
+
+        return tokens, weights
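+
+    # Worked example (sketch): with chunk_length=77 and max_embeddings_multiples=2,
+    # max_length = (77 - 2) * 2 + 2 = 152. A prompt of 80 content tokens becomes
+    # [bos] + 80 tokens + [eos] + 70 * [pad] (152 entries), and with
+    # no_boseos_middle=True its weights become [1.0] + 80 weights + 71 * [1.0].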
+
+    def get_unweighted_text_embeddings(
+        self,
+        text_encoder: CLIPTextModel,
+        text_input: torch.Tensor,
+        chunk_length: int,
+        clip_skip: int,
+        eos: int,
+        pad: int,
+        no_boseos_middle: Optional[bool] = True,
+    ):
+        """
+        When the token length exceeds the capacity of the text encoder, the input is
+        split into chunks and each chunk is sent to the text encoder individually.
+        """
+        max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+        if max_embeddings_multiples > 1:
+            text_embeddings = []
+            pool = None
+            for i in range(max_embeddings_multiples):
+                # extract the i-th chunk
+                text_input_chunk = text_input[
+                    :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
+                ].clone()
+
+                # cover the head and the tail by the starting and the ending tokens
+                text_input_chunk[:, 0] = text_input[0, 0]
+                if pad == eos:  # v1
+                    text_input_chunk[:, -1] = text_input[0, -1]
+                else:  # v2
+                    for j in range(len(text_input_chunk)):
+                        # the last token is a normal token (not EOS or PAD)
+                        if (
+                            text_input_chunk[j, -1] != eos
+                            and text_input_chunk[j, -1] != pad
+                        ):
+                            text_input_chunk[j, -1] = eos
+                        if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
+                            text_input_chunk[j, 1] = eos
+
+                # -2 is same for Text Encoder 1 and 2
+                enc_out = text_encoder(
+                    text_input_chunk, output_hidden_states=True, return_dict=True
+                )
+                text_embedding = enc_out["hidden_states"][-2]
+                if pool is None:
+                    # use 1st chunk, if provided
+                    pool = enc_out.get("text_embeds", None)
+                    if pool is not None:
+                        pool = train_util.pool_workaround(
+                            text_encoder,
+                            enc_out["last_hidden_state"],
+                            text_input_chunk,
+                            eos,
+                        )
+
+                if no_boseos_middle:
+                    if i == 0:
+                        # discard the ending token
+                        text_embedding = text_embedding[:, :-1]
+                    elif i == max_embeddings_multiples - 1:
+                        # discard the starting token
+                        text_embedding = text_embedding[:, 1:]
+                    else:
+                        # discard both starting and ending tokens
+                        text_embedding = text_embedding[:, 1:-1]
+
+                text_embeddings.append(text_embedding)
+            text_embeddings = torch.concat(text_embeddings, axis=1)
+        else:
+            enc_out = text_encoder(
+                text_input, output_hidden_states=True, return_dict=True
+            )
+            text_embeddings = enc_out["hidden_states"][-2]
+            # text encoder 1 doesn't return this
+            pool = enc_out.get("text_embeds", None)
+            if pool is not None:
+                pool = train_util.pool_workaround(
+                    text_encoder, enc_out["last_hidden_state"], text_input, eos
+                )
+        return text_embeddings, pool
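+
+    # Chunking sketch: for a padded input of 152 tokens and chunk_length=77,
+    # max_embeddings_multiples = (152 - 2) // (77 - 2) = 2, so two 77-token windows
+    # (75 content tokens each, re-capped with the starting/ending tokens) are encoded
+    # separately and their hidden states are concatenated along the sequence dimension.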
+
+    def get_weighted_text_embeddings(
+        self,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        prompt: Union[str, List[str]],
+        uncond_prompt: Optional[Union[str, List[str]]] = None,
+        max_embeddings_multiples: Optional[int] = 1,
+        no_boseos_middle: Optional[bool] = False,
+        skip_parsing: Optional[bool] = False,
+        skip_weighting: Optional[bool] = False,
+        clip_skip=None,
+        token_replacer=None,
+        device=None,
+        **kwargs,
+    ):
+        max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
+        # split the prompts with "AND". each prompt must have the same number of splits
+        new_prompts = []
+        for p in prompt:
+            new_prompts.extend(p.split(" AND "))
+        prompt = new_prompts
+
+        if not skip_parsing:
+            prompt_tokens, prompt_weights = self.get_prompts_with_weights(
+                tokenizer, token_replacer, prompt, max_length - 2
+            )
+            if uncond_prompt is not None:
+                if isinstance(uncond_prompt, str):
+                    uncond_prompt = [uncond_prompt]
+                uncond_tokens, uncond_weights = self.get_prompts_with_weights(
+                    tokenizer, token_replacer, uncond_prompt, max_length - 2
+                )
+        else:
+            prompt_tokens = [
+                token[1:-1]
+                for token in tokenizer(
+                    prompt, max_length=max_length, truncation=True
+                ).input_ids
+            ]
+            prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+            if uncond_prompt is not None:
+                if isinstance(uncond_prompt, str):
+                    uncond_prompt = [uncond_prompt]
+                uncond_tokens = [
+                    token[1:-1]
+                    for token in tokenizer(
+                        uncond_prompt, max_length=max_length, truncation=True
+                    ).input_ids
+                ]
+                uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+        # round up the longest length of tokens to a multiple of (model_max_length - 2)
+        max_length = max([len(token) for token in prompt_tokens])
+        if uncond_prompt is not None:
+            max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+        max_embeddings_multiples = min(
+            max_embeddings_multiples,
+            (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
+        )
+        max_embeddings_multiples = max(1, max_embeddings_multiples)
+        max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+        # pad the length of tokens and weights
+        bos = tokenizer.bos_token_id
+        eos = tokenizer.eos_token_id
+        pad = tokenizer.pad_token_id
+        prompt_tokens, prompt_weights = self.pad_tokens_and_weights(
+            prompt_tokens,
+            prompt_weights,
+            max_length,
+            bos,
+            eos,
+            pad,
+            no_boseos_middle=no_boseos_middle,
+            chunk_length=tokenizer.model_max_length,
+        )
+        prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
+        if uncond_prompt is not None:
+            uncond_tokens, uncond_weights = self.pad_tokens_and_weights(
+                uncond_tokens,
+                uncond_weights,
+                max_length,
+                bos,
+                eos,
+                pad,
+                no_boseos_middle=no_boseos_middle,
+                chunk_length=tokenizer.model_max_length,
+            )
+            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
+
+        # get the embeddings
+        text_embeddings, text_pool = self.get_unweighted_text_embeddings(
+            text_encoder,
+            prompt_tokens,
+            tokenizer.model_max_length,
+            clip_skip,
+            eos,
+            pad,
+            no_boseos_middle=no_boseos_middle,
+        )
+        prompt_weights = torch.tensor(
+            prompt_weights, dtype=text_embeddings.dtype, device=device
+        )
+        if uncond_prompt is not None:
+            uncond_embeddings, uncond_pool = self.get_unweighted_text_embeddings(
+                text_encoder,
+                uncond_tokens,
+                tokenizer.model_max_length,
+                clip_skip,
+                eos,
+                pad,
+                no_boseos_middle=no_boseos_middle,
+            )
+            uncond_weights = torch.tensor(
+                uncond_weights, dtype=uncond_embeddings.dtype, device=device
+            )
+
+        # assign weights to the prompts and normalize in the sense of mean
+        # TODO: should we normalize per chunk or over the whole sequence (current implementation)?
+        # → normalizing over the whole sequence is probably fine
+        if (not skip_parsing) and (not skip_weighting):
+            previous_mean = (
+                text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            )
+            text_embeddings *= prompt_weights.unsqueeze(-1)
+            current_mean = (
+                text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+            )
+            text_embeddings *= (
+                (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+            )
+            if uncond_prompt is not None:
+                previous_mean = (
+                    uncond_embeddings.float()
+                    .mean(axis=[-2, -1])
+                    .to(uncond_embeddings.dtype)
+                )
+                uncond_embeddings *= uncond_weights.unsqueeze(-1)
+                current_mean = (
+                    uncond_embeddings.float()
+                    .mean(axis=[-2, -1])
+                    .to(uncond_embeddings.dtype)
+                )
+                uncond_embeddings *= (
+                    (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+                )
+
+        if uncond_prompt is not None:
+            return (
+                text_embeddings,
+                text_pool,
+                uncond_embeddings,
+                uncond_pool,
+                prompt_tokens,
+            )
+        return text_embeddings, text_pool, None, None, prompt_tokens
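+
+    # Weighting sketch: each token embedding is scaled by its weight and the sequence is
+    # then rescaled so that its mean matches the pre-weighting mean,
+    # i.e. emb' = emb * w * mean(emb) / mean(emb * w), which shifts emphasis between
+    # tokens while keeping the overall magnitude roughly unchanged.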
+
+    def get_token_replacer(self, tokenizer):
+        tokenizer_index = self.tokenizers.index(tokenizer)
+        token_replacements = self.token_replacements_list[tokenizer_index]
+
+        def replace_tokens(tokens):
+            # print("replace_tokens", tokens, "=>", token_replacements)
+            if isinstance(tokens, torch.Tensor):
+                tokens = tokens.tolist()
+
+            new_tokens = []
+            for token in tokens:
+                if token in token_replacements:
+                    replacement = token_replacements[token]
+                    new_tokens.extend(replacement)
+                else:
+                    new_tokens.append(token)
+            return new_tokens
+
+        return replace_tokens
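+
+    # Replacement sketch (hypothetical ids): if token_replacements maps 49408 to
+    # [49408, 49409, 49410] for a Textual Inversion embedding spanning three vectors,
+    # replace_tokens([320, 49408, 525]) returns [320, 49408, 49409, 49410, 525].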
+
+    def set_control_nets(self, ctrl_nets):
+        self.control_nets = ctrl_nets
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        init_image: Union[
+            torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
+        ] = None,
+        mask_image: Union[
+            torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
+        ] = None,
+        height: int = 1024,
+        width: int = 1024,
+        original_height: int = None,
+        original_width: int = None,
+        original_height_negative: int = None,
+        original_width_negative: int = None,
+        crop_top: int = 0,
+        crop_left: int = 0,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_scale: float = None,
+        strength: float = 0.8,
+        # num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        vae_batch_size: float = None,
+        return_latents: bool = False,
+        # return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        callback_steps: Optional[int] = 1,
+        img2img_noise=None,
+        clip_guide_images=None,
+        **kwargs,
+    ):
+        # TODO support secondary prompt
+        num_images_per_prompt = 1  # fixed because already prompt is repeated
+
+        if isinstance(prompt, str):
+            batch_size = 1
+            prompt = [prompt]
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+        regional_network = " AND " in prompt[0]
+
+        vae_batch_size = (
+            batch_size
+            if vae_batch_size is None
+            else (
+                int(vae_batch_size)
+                if vae_batch_size >= 1
+                else max(1, int(batch_size * vae_batch_size))
+            )
+        )
+
+        if strength < 0 or strength > 1:
+            raise ValueError(
+                f"The value of strength should in [0.0, 1.0] but is {strength}"
+            )
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
+
+        if (callback_steps is None) or (
+            callback_steps is not None
+            and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        # get prompt text embeddings
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        if not do_classifier_free_guidance and negative_scale is not None:
+            print(f"negative_scale is ignored if guidance scalle <= 1.0")
+            negative_scale = None
+
+        # get unconditional embeddings for classifier free guidance
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+
+        tes_text_embs = []
+        tes_uncond_embs = []
+        tes_real_uncond_embs = []
+
+        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+            token_replacer = self.get_token_replacer(tokenizer)
+
+            # use last text_pool, because it is from text encoder 2
+            (
+                text_embeddings,
+                text_pool,
+                uncond_embeddings,
+                uncond_pool,
+                _,
+            ) = self.get_weighted_text_embeddings(
+                tokenizer,
+                text_encoder,
+                prompt=prompt,
+                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+                max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=self.clip_skip,
+                token_replacer=token_replacer,
+                device=self.device,
+                **kwargs,
+            )
+            tes_text_embs.append(text_embeddings)
+            tes_uncond_embs.append(uncond_embeddings)
+
+            if negative_scale is not None:
+                (
+                    _,
+                    _,
+                    real_uncond_embeddings,
+                    _,
+                    _,
+                ) = self.get_weighted_text_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompt=prompt,  # build uncond to match this token length; required when exceeding 75 tokens
+                    uncond_prompt=[""] * batch_size,
+                    max_embeddings_multiples=max_embeddings_multiples,
+                    clip_skip=self.clip_skip,
+                    token_replacer=token_replacer,
+                    device=self.device,
+                    **kwargs,
+                )
+                tes_real_uncond_embs.append(real_uncond_embeddings)
+
+        # concat text encoder outputs
+        text_embeddings = tes_text_embs[0]
+        uncond_embeddings = tes_uncond_embs[0]
+        for i in range(1, len(tes_text_embs)):
+            text_embeddings = torch.cat(
+                [text_embeddings, tes_text_embs[i]], dim=2
+            )  # n,77,2048
+            if do_classifier_free_guidance:
+                uncond_embeddings = torch.cat(
+                    [uncond_embeddings, tes_uncond_embs[i]], dim=2
+                )  # n,77,2048
+
+        if do_classifier_free_guidance:
+            if negative_scale is None:
+                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            else:
+                text_embeddings = torch.cat(
+                    [uncond_embeddings, text_embeddings, real_uncond_embeddings]
+                )
+
+        if self.control_nets:
+            # reuse the guide image as the ControlNet hint
+            if isinstance(clip_guide_images, PIL.Image.Image):
+                clip_guide_images = [clip_guide_images]
+            if isinstance(clip_guide_images[0], PIL.Image.Image):
+                clip_guide_images = [
+                    self.preprocess_image(im) for im in clip_guide_images
+                ]
+                clip_guide_images = torch.cat(clip_guide_images)
+            if isinstance(clip_guide_images, list):
+                clip_guide_images = torch.stack(clip_guide_images)
+
+            clip_guide_images = clip_guide_images.to(
+                self.device, dtype=text_embeddings.dtype
+            )
+
+        # create size embs
+        if original_height is None:
+            original_height = height
+        if original_width is None:
+            original_width = width
+        if original_height_negative is None:
+            original_height_negative = original_height
+        if original_width_negative is None:
+            original_width_negative = original_width
+        if crop_top is None:
+            crop_top = 0
+        if crop_left is None:
+            crop_left = 0
+        emb1 = sdxl_train_util.get_timestep_embedding(
+            torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256
+        )
+        uc_emb1 = sdxl_train_util.get_timestep_embedding(
+            torch.FloatTensor(
+                [original_height_negative, original_width_negative]
+            ).unsqueeze(0),
+            256,
+        )
+        emb2 = sdxl_train_util.get_timestep_embedding(
+            torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256
+        )
+        emb3 = sdxl_train_util.get_timestep_embedding(
+            torch.FloatTensor([height, width]).unsqueeze(0), 256
+        )
+        c_vector = (
+            torch.cat([emb1, emb2, emb3], dim=1)
+            .to(self.device, dtype=text_embeddings.dtype)
+            .repeat(batch_size, 1)
+        )
+        uc_vector = (
+            torch.cat([uc_emb1, emb2, emb3], dim=1)
+            .to(self.device, dtype=text_embeddings.dtype)
+            .repeat(batch_size, 1)
+        )
+
+        if regional_network:
+            # use last pool for conditioning
+            num_sub_prompts = len(text_pool) // batch_size
+            text_pool = text_pool[
+                num_sub_prompts - 1 :: num_sub_prompts
+            ]  # last subprompt
+
+        if init_image is not None and self.clip_vision_model is not None:
+            print(
+                f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}"
+            )
+            vision_input = self.clip_vision_processor(
+                init_image, return_tensors="pt", device=self.device
+            )
+            pixel_values = vision_input["pixel_values"].to(
+                self.device, dtype=text_embeddings.dtype
+            )
+
+            clip_vision_embeddings = self.clip_vision_model(
+                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+            )
+            clip_vision_embeddings = clip_vision_embeddings.image_embeds
+
+            if len(clip_vision_embeddings) == 1 and batch_size > 1:
+                clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))
+
+            clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
+            assert (
+                clip_vision_embeddings.shape == text_pool.shape
+            ), f"{clip_vision_embeddings.shape} != {text_pool.shape}"
+            text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?)
+
+        c_vector = torch.cat([text_pool, c_vector], dim=1)
+        if do_classifier_free_guidance:
+            uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
+            vector_embeddings = torch.cat([uc_vector, c_vector])
+        else:
+            vector_embeddings = c_vector
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, self.device)
+
+        latents_dtype = text_embeddings.dtype
+        init_latents_orig = None
+        mask = None
+
+        if init_image is None:
+            # get the initial random noise unless the user supplied it
+
+            # Unlike in other pipelines, latents need to be generated in the target device
+            # for 1-to-1 results reproducibility with the CompVis implementation.
+            # However this currently doesn't work in `mps`.
+            latents_shape = (
+                batch_size * num_images_per_prompt,
+                self.unet.in_channels,
+                height // 8,
+                width // 8,
+            )
+
+            if latents is None:
+                if self.device.type == "mps":
+                    # randn does not work reproducibly on mps, so sample on CPU and move to the device
+                    latents = torch.randn(
+                        latents_shape,
+                        generator=generator,
+                        device="cpu",
+                        dtype=latents_dtype,
+                    ).to(self.device)
+                else:
+                    latents = torch.randn(
+                        latents_shape,
+                        generator=generator,
+                        device=self.device,
+                        dtype=latents_dtype,
+                    )
+            else:
+                if latents.shape != latents_shape:
+                    raise ValueError(
+                        f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
+                    )
+                latents = latents.to(self.device)
+
+            timesteps = self.scheduler.timesteps.to(self.device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+        else:
+            # image to tensor
+            if isinstance(init_image, PIL.Image.Image):
+                init_image = [init_image]
+            if isinstance(init_image[0], PIL.Image.Image):
+                init_image = [self.preprocess_image(im) for im in init_image]
+                init_image = torch.cat(init_image)
+            if isinstance(init_image, list):
+                init_image = torch.stack(init_image)
+
+            # mask image to tensor
+            if mask_image is not None:
+                if isinstance(mask_image, PIL.Image.Image):
+                    mask_image = [mask_image]
+                if isinstance(mask_image[0], PIL.Image.Image):
+                    mask_image = torch.cat(
+                        [self.preprocess_mask(im) for im in mask_image]
+                    )  # H*W, 0 for repaint
+
+            # encode the init image into latents and scale the latents
+            init_image = init_image.to(device=self.device, dtype=latents_dtype)
+            if init_image.size()[-2:] == (height // 8, width // 8):
+                init_latents = init_image
+            else:
+                if vae_batch_size >= batch_size:
+                    init_latent_dist = self.vae.encode(
+                        init_image.to(self.vae.dtype)
+                    ).latent_dist
+                    init_latents = init_latent_dist.sample(generator=generator)
+                else:
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    init_latents = []
+                    for i in tqdm(
+                        range(0, min(batch_size, len(init_image)), vae_batch_size)
+                    ):
+                        init_latent_dist = self.vae.encode(
+                            (
+                                init_image[i : i + vae_batch_size]
+                                if vae_batch_size > 1
+                                else init_image[i].unsqueeze(0)
+                            ).to(self.vae.dtype)
+                        ).latent_dist
+                        init_latents.append(
+                            init_latent_dist.sample(generator=generator)
+                        )
+                    init_latents = torch.cat(init_latents)
+
+                init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
+
+            if len(init_latents) == 1:
+                init_latents = init_latents.repeat((batch_size, 1, 1, 1))
+            init_latents_orig = init_latents
+
+            # preprocess mask
+            if mask_image is not None:
+                mask = mask_image.to(device=self.device, dtype=latents_dtype)
+                if len(mask) == 1:
+                    mask = mask.repeat((batch_size, 1, 1, 1))
+
+                # check sizes
+                if not mask.shape == init_latents.shape:
+                    raise ValueError("The mask and init_image should be the same size!")
+
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor(
+                [timesteps] * batch_size * num_images_per_prompt, device=self.device
+            )
+
+            # add noise to latents using the timesteps
+            latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        num_latent_input = (
+            (3 if negative_scale is not None else 2)
+            if do_classifier_free_guidance
+            else 1
+        )
+
+        if self.control_nets:
+            # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+            if self.control_net_enabled:
+                for control_net, _ in self.control_nets:
+                    with torch.no_grad():
+                        control_net.set_cond_image(clip_guide_images)
+            else:
+                for control_net, _ in self.control_nets:
+                    control_net.set_cond_image(None)
+
+        each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
+        for i, t in enumerate(tqdm(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # disable control net if ratio is set
+            if self.control_nets and self.control_net_enabled:
+                for j, ((control_net, ratio), enabled) in enumerate(
+                    zip(self.control_nets, each_control_net_enabled)
+                ):
+                    if not enabled or ratio >= 1.0:
+                        continue
+                    if ratio < i / len(timesteps):
+                        print(
+                            f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})"
+                        )
+                        control_net.set_cond_image(None)
+                        each_control_net_enabled[j] = False
+
+            # predict the noise residual
+            # TODO Diffusers' ControlNet
+            # if self.control_nets and self.control_net_enabled:
+            #     if regional_network:
+            #         num_sub_and_neg_prompts = len(text_embeddings) // batch_size
+            #         text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
+            #     else:
+            #         text_emb_last = text_embeddings
+
+            #     # not working yet
+            #     noise_pred = original_control_net.call_unet_and_control_net(
+            #         i,
+            #         num_latent_input,
+            #         self.unet,
+            #         self.control_nets,
+            #         guided_hints,
+            #         i / len(timesteps),
+            #         latent_model_input,
+            #         t,
+            #         text_emb_last,
+            #     ).sample
+            # else:
+            noise_pred = self.unet(
+                latent_model_input, t, text_embeddings, vector_embeddings
+            )
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                if negative_scale is None:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(
+                        num_latent_input
+                    )  # uncond by negative prompt
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                else:
+                    (
+                        noise_pred_negative,
+                        noise_pred_text,
+                        noise_pred_uncond,
+                    ) = noise_pred.chunk(
+                        num_latent_input
+                    )  # uncond is real uncond
+                    noise_pred = (
+                        noise_pred_uncond
+                        + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        - negative_scale * (noise_pred_negative - noise_pred_uncond)
+                    )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs
+            ).prev_sample
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(
+                    init_latents_orig, img2img_noise, torch.tensor([t])
+                )
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None
+
+        if return_latents:
+            return latents
+
+        latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
+        if vae_batch_size >= batch_size:
+            image = self.vae.decode(latents.to(self.vae.dtype)).sample
+        else:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            images = []
+            for i in tqdm(range(0, batch_size, vae_batch_size)):
+                images.append(
+                    self.vae.decode(
+                        (
+                            latents[i : i + vae_batch_size]
+                            if vae_batch_size > 1
+                            else latents[i].unsqueeze(0)
+                        ).to(self.vae.dtype)
+                    ).sample
+                )
+            image = torch.cat(images)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        if output_type == "pil":
+            # image = self.numpy_to_pil(image)
+            image = (image * 255).round().astype("uint8")
+            image = [Image.fromarray(im) for im in image]
+
+        return image
+
+
+class SDXLLLiteImg2ImgPipeline:
+    def __init__(self):
+        self.SCHEDULER_LINEAR_START = 0.00085
+        self.SCHEDULER_LINEAR_END = 0.0120
+        self.SCHEDULER_TIMESTEPS = 1000
+        self.SCHEDULER_SCHEDULE = "scaled_linear"
+        self.LATENT_CHANNELS = 4
+        self.DOWNSAMPLING_FACTOR = 8
+
+    def replace_unet_modules(
+        self,
+        unet: diffusers.models.unet_2d_condition.UNet2DConditionModel,
+        mem_eff_attn,
+        xformers,
+        sdpa,
+    ):
+        if mem_eff_attn:
+            print("Enable memory efficient attention for U-Net")
+
+            # this is our own U-Net, not the Diffusers one, so the modules do not need to be replaced
+            unet.set_use_memory_efficient_attention(False, True)
+        elif xformers:
+            print("Enable xformers for U-Net")
+            try:
+                import xformers.ops
+            except ImportError:
+                raise ImportError("No xformers / xformersがインストールされていないようです")
+
+            unet.set_use_memory_efficient_attention(True, False)
+        elif sdpa:
+            print("Enable SDPA for U-Net")
+            unet.set_use_memory_efficient_attention(False, False)
+            unet.set_use_sdpa(True)
+
+    # TODO common train_util.py
+    def replace_vae_modules(
+        self, vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa
+    ):
+        if mem_eff_attn:
+            self.replace_vae_attn_to_memory_efficient()
+        elif xformers:
+            # replace_vae_attn_to_xformers()  # xformers may raise an error depending on the resolution?
+            vae.set_use_memory_efficient_attention_xformers(True)  # use this for now
+        elif sdpa:
+            self.replace_vae_attn_to_sdpa()
+
+    def replace_vae_attn_to_memory_efficient(self):
+        print(
+            "VAE Attention.forward has been replaced to FlashAttention (not xformers)"
+        )
+        flash_func = FlashAttentionFunction
+
+        def forward_flash_attn(self, hidden_states, **kwargs):
+            q_bucket_size = 512
+            k_bucket_size = 1024
+
+            residual = hidden_states
+            batch, channel, height, width = hidden_states.shape
+
+            # norm
+            hidden_states = self.group_norm(hidden_states)
+
+            hidden_states = hidden_states.view(
+                batch, channel, height * width
+            ).transpose(1, 2)
+
+            # proj to q, k, v
+            query_proj = self.to_q(hidden_states)
+            key_proj = self.to_k(hidden_states)
+            value_proj = self.to_v(hidden_states)
+
+            query_proj, key_proj, value_proj = map(
+                lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads),
+                (query_proj, key_proj, value_proj),
+            )
+
+            out = flash_func.apply(
+                query_proj,
+                key_proj,
+                value_proj,
+                None,
+                False,
+                q_bucket_size,
+                k_bucket_size,
+            )
+
+            out = rearrange(out, "b h n d -> b n (h d)")
+
+            # compute next hidden_states
+            # linear proj
+            hidden_states = self.to_out[0](out)  # project the attention output, not the pre-attention hidden states
+            # dropout
+            hidden_states = self.to_out[1](hidden_states)
+
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch, channel, height, width
+            )
+
+            # res connect and rescale
+            hidden_states = (hidden_states + residual) / self.rescale_output_factor
+            return hidden_states
+
+        def forward_flash_attn_0_14(self, hidden_states, **kwargs):
+            if not hasattr(self, "to_q"):
+                self.to_q = self.query
+                self.to_k = self.key
+                self.to_v = self.value
+                self.to_out = [self.proj_attn, torch.nn.Identity()]
+                self.heads = self.num_heads
+            return forward_flash_attn(self, hidden_states, **kwargs)
+
+        if diffusers.__version__ < "0.15.0":
+            diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
+        else:
+            diffusers.models.attention_processor.Attention.forward = forward_flash_attn
+
+    def replace_vae_attn_to_xformers(self):
+        print("VAE: Attention.forward has been replaced to xformers")
+        import xformers.ops
+
+        def forward_xformers(self, hidden_states, **kwargs):
+            residual = hidden_states
+            batch, channel, height, width = hidden_states.shape
+
+            # norm
+            hidden_states = self.group_norm(hidden_states)
+
+            hidden_states = hidden_states.view(
+                batch, channel, height * width
+            ).transpose(1, 2)
+
+            # proj to q, k, v
+            query_proj = self.to_q(hidden_states)
+            key_proj = self.to_k(hidden_states)
+            value_proj = self.to_v(hidden_states)
+
+            query_proj, key_proj, value_proj = map(
+                lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads),
+                (query_proj, key_proj, value_proj),
+            )
+
+            query_proj = query_proj.contiguous()
+            key_proj = key_proj.contiguous()
+            value_proj = value_proj.contiguous()
+            out = xformers.ops.memory_efficient_attention(
+                query_proj, key_proj, value_proj, attn_bias=None
+            )
+
+            out = rearrange(out, "b h n d -> b n (h d)")
+
+            # compute next hidden_states
+            # linear proj
+            hidden_states = self.to_out[0](out)  # project the attention output, not the pre-attention hidden states
+            # dropout
+            hidden_states = self.to_out[1](hidden_states)
+
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch, channel, height, width
+            )
+
+            # res connect and rescale
+            hidden_states = (hidden_states + residual) / self.rescale_output_factor
+            return hidden_states
+
+        def forward_xformers_0_14(self, hidden_states, **kwargs):
+            if not hasattr(self, "to_q"):
+                self.to_q = self.query
+                self.to_k = self.key
+                self.to_v = self.value
+                self.to_out = [self.proj_attn, torch.nn.Identity()]
+                self.heads = self.num_heads
+            return forward_xformers(self, hidden_states, **kwargs)
+
+        if diffusers.__version__ < "0.15.0":
+            diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
+        else:
+            diffusers.models.attention_processor.Attention.forward = forward_xformers
+
+    def replace_vae_attn_to_sdpa(self):
+        print("VAE: Attention.forward has been replaced with sdpa")
+
+        def forward_sdpa(self, hidden_states, **kwargs):
+            residual = hidden_states
+            batch, channel, height, width = hidden_states.shape
+
+            # norm
+            hidden_states = self.group_norm(hidden_states)
+
+            hidden_states = hidden_states.view(
+                batch, channel, height * width
+            ).transpose(1, 2)
+
+            # proj to q, k, v
+            query_proj = self.to_q(hidden_states)
+            key_proj = self.to_k(hidden_states)
+            value_proj = self.to_v(hidden_states)
+
+            query_proj, key_proj, value_proj = map(
+                lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads),
+                (query_proj, key_proj, value_proj),
+            )
+
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query_proj,
+                key_proj,
+                value_proj,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+            )
+
+            out = rearrange(out, "b n h d -> b n (h d)")
+
+            # compute next hidden_states
+            # linear proj
+            hidden_states = self.to_out[0](out)  # project the attention output, not the pre-attention hidden states
+            # dropout
+            hidden_states = self.to_out[1](hidden_states)
+
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch, channel, height, width
+            )
+
+            # res connect and rescale
+            hidden_states = (hidden_states + residual) / self.rescale_output_factor
+            return hidden_states
+
+        def forward_sdpa_0_14(self, hidden_states, **kwargs):
+            if not hasattr(self, "to_q"):
+                self.to_q = self.query
+                self.to_k = self.key
+                self.to_v = self.value
+                self.to_out = [self.proj_attn, torch.nn.Identity()]
+                self.heads = self.num_heads
+            return forward_sdpa(self, hidden_states, **kwargs)
+
+        if diffusers.__version__ < "0.15.0":
+            diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14
+        else:
+            diffusers.models.attention_processor.Attention.forward = forward_sdpa
+
+    def load(self, pipeline: AbstractPipeline, controlnet_urls: Optional[List[str]]):
+        pipeline.pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+        pipeline.pipe.fuse_lora()
+
+        self.dtype = pipeline.pipe.dtype
+        self.device = pipeline.pipe.device
+        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(
+            pipeline.pipe.unet.state_dict()
+        )
+        with init_empty_weights():
+            original_unet = (
+                sdxl_original_unet.SdxlUNet2DConditionModel()
+            )  # overwrite unet
+        sdxl_model_util._load_state_dict_on_device(
+            original_unet,
+            state_dict,
+            device=pipeline.pipe.device,
+            dtype=pipeline.pipe.dtype,
+        )
+        unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(
+            original_unet
+        )
+        sched_init_args = {}
+        has_steps_offset = True
+        has_clip_sample = True
+        scheduler_num_noises_per_step = 1
+
+        mem_eff = False  # memory-efficient attention is not used; xformers is enabled instead
+        self.replace_unet_modules(unet, mem_eff, True, False)
+        self.replace_vae_modules(pipeline.pipe.vae, mem_eff, True, False)
+
+        scheduler_cls = LCMScheduler
+        scheduler_module = diffusers.schedulers.scheduling_ddim
+
+        if has_steps_offset:
+            sched_init_args["steps_offset"] = 1
+        if has_clip_sample:
+            sched_init_args["clip_sample"] = False
+
+        class NoiseManager:
+            def __init__(self):
+                self.sampler_noises = None
+                self.sampler_noise_index = 0
+
+            def reset_sampler_noises(self, noises):
+                self.sampler_noise_index = 0
+                self.sampler_noises = noises
+
+            def randn(
+                self, shape, device=None, dtype=None, layout=None, generator=None
+            ):
+                # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
+                if self.sampler_noises is not None and self.sampler_noise_index < len(
+                    self.sampler_noises
+                ):
+                    noise = self.sampler_noises[self.sampler_noise_index]
+                    if shape != noise.shape:
+                        noise = None
+                else:
+                    noise = None
+
+                if noise is None:
+                    print(
+                        f"unexpected noise request: {self.sampler_noise_index}, {shape}"
+                    )
+                    noise = torch.randn(
+                        shape, dtype=dtype, device=device, generator=generator
+                    )
+
+                self.sampler_noise_index += 1
+                return noise
+
+        class TorchRandReplacer:
+            def __init__(self, noise_manager):
+                self.noise_manager = noise_manager
+
+            def __getattr__(self, item):
+                if item == "randn":
+                    return self.noise_manager.randn
+                if hasattr(torch, item):
+                    return getattr(torch, item)
+                raise AttributeError(
+                    "'{}' object has no attribute '{}'".format(
+                        type(self).__name__, item
+                    )
+                )
+
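+        # NoiseManager / TorchRandReplacer sketch: the scheduler module's `torch`
+        # reference is replaced so that internal torch.randn calls draw from
+        # pre-supplied sampler noises (when set) and only fall back to fresh random
+        # noise otherwise, which keeps the per-step noise injection reproducible.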
+        noise_manager = NoiseManager()
+        if scheduler_module is not None:
+            scheduler_module.torch = TorchRandReplacer(noise_manager)
+
+        scheduler = scheduler_cls(
+            num_train_timesteps=self.SCHEDULER_TIMESTEPS,
+            beta_start=self.SCHEDULER_LINEAR_START,
+            beta_end=self.SCHEDULER_LINEAR_END,
+            beta_schedule=self.SCHEDULER_SCHEDULE,
+            **sched_init_args,
+        )
+        device = torch.device(
+            pipeline.pipe.device if torch.cuda.is_available() else "cpu"
+        )
+        # vae.to(vae_dtype).to(device)
+        # vae.eval()
+        # text_encoder1.to(dtype).to(device)
+        # text_encoder2.to(dtype).to(device)
+        print(pipeline.pipe.dtype)
+        unet.to(pipeline.pipe.dtype).to(pipeline.pipe.device)
+        # text_encoder1.eval()
+        # text_encoder2.eval()
+        unet.eval()
+        control_nets: List[Tuple[ControlNetLLLite, float]] = []
+        for link in controlnet_urls:
+            net_path = Path.home() / ".cache" / link.split("/")[-1]
+            download_file(link, net_path)
+            print(f"loading controlnet {net_path}")
+            state_dict = load_file(net_path)
+            mlp_dim = None
+            cond_emb_dim = None
+            for key, value in state_dict.items():
+                if mlp_dim is None and "down.0.weight" in key:
+                    mlp_dim = value.shape[0]
+                elif cond_emb_dim is None and "conditioning1.0" in key:
+                    cond_emb_dim = value.shape[0] * 2
+                if mlp_dim is not None and cond_emb_dim is not None:
+                    break
+            assert (
+                mlp_dim is not None and cond_emb_dim is not None
+            ), f"invalid control net: {link}"
+
+            multiplier = 0.2
+            # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+            ratio = 1.0
+
+            control_net = ControlNetLLLite(
+                unet, cond_emb_dim, mlp_dim, multiplier=multiplier
+            )
+            control_net.apply_to()
+            control_net.load_state_dict(state_dict)
+            control_net.to(pipeline.pipe.dtype).to(device)
+            control_net.set_batch_cond_only(False, False)
+            control_nets.append((control_net, ratio))
+
+        networks = []
+        self.pipe = PipelineLike(
+            device,
+            pipeline.pipe.vae,
+            [pipeline.pipe.text_encoder, pipeline.pipe.text_encoder_2],
+            [pipeline.pipe.tokenizer, pipeline.pipe.tokenizer_2],
+            unet,
+            scheduler,
+            2,
+        )
+        self.pipe.set_control_nets(control_nets)
+
+        clear_cuda_and_gc()
+
+        pipeline.pipe.unload_lora_weights()
+        pipeline.pipe.unfuse_lora()
+
+        clear_cuda_and_gc()
+
+    def __call__(
+        self,
+        prompt: str,
+        negative_prompt: str,
+        seed: int,
+        image: Image.Image,
+        condition_image: Union[Image.Image, List[Image.Image]],
+        height: int = 1024,
+        width: int = 1024,
+        num_inference_steps: int = 24,
+        guidance_scale=1.0,
+    ):
+        noise_shape = (
+            self.LATENT_CHANNELS,
+            height // self.DOWNSAMPLING_FACTOR,
+            width // self.DOWNSAMPLING_FACTOR,
+        )
+        i2i_noises = torch.zeros(
+            (1, *noise_shape), device=self.device, dtype=self.dtype
+        )
+        i2i_noises[0] = torch.randn(noise_shape, device=self.device, dtype=self.dtype)
+        images = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            seed=seed,
+            init_image=image,
+            height=height,
+            width=width,
+            strength=1.0,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            clip_guide_images=condition_image,
+            img2img_noise=i2i_noises,
+        )
+        return images
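+
+
+# Usage sketch (illustrative; the base `AbstractPipeline` instance and the
+# ControlNet-LLLite weight URL below are assumptions of this example, not part of
+# this module):
+#
+#     lite = SDXLLLiteImg2ImgPipeline()
+#     lite.load(base_pipeline, ["https://example.com/lllite-canny.safetensors"])
+#     images = lite(
+#         prompt="a photo of a cat",
+#         negative_prompt="",
+#         seed=42,
+#         image=init_image,            # PIL.Image used as the img2img input
+#         condition_image=canny_image, # PIL.Image used as the LLLite hint
+#         num_inference_steps=8,
+#         guidance_scale=1.0,          # LCM-LoRA is fused, so CFG is effectively off
+#     )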
diff --git a/requirements.txt b/requirements.txt
index 4b20b72849abc09355df1136b0c41637890b05d8..b73f6642fe268bc499a96760c2fff6922b88a369 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,3 +43,4 @@ onnx
 onnxruntime-gpu
 imgaug==0.4.0
 tqdm==4.64.1
+toml