diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f799589bc9555d123a833ce2a704e6a8b011845a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+.venv
+venv/
+logs/
diff --git a/cache_latents.py b/cache_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..61446a768082885e13df26cbb7b8cc0ad649c116
--- /dev/null
+++ b/cache_latents.py
@@ -0,0 +1,245 @@
+import argparse
+import os
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from PIL import Image
+
+import logging
+
+from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache
+from hunyuan_model.vae import load_vae
+from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from utils.model_utils import str_to_dtype
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
+ import cv2
+
+ imgs = (
+ [image]
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
+ else [image[0], image[-1]]
+ )
+ if len(imgs) > 1:
+ print(f"Number of images: {len(image)}")
+ for i, img in enumerate(imgs):
+ if len(imgs) > 1:
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
+ else:
+ print(f"Image: {img.shape}")
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
+ cv2.imshow("image", cv2_img)
+ k = cv2.waitKey(0)
+ cv2.destroyAllWindows()
+ if k == ord("q") or k == ord("d"):
+ return k
+ return k
+
+
+def show_console(
+ image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
+ width: int,
+ back: str,
+ interactive: bool = False,
+) -> int:
+ from ascii_magic import from_pillow_image, Back
+
+ back = None
+ if back is not None:
+ back = getattr(Back, back.upper())
+
+ k = None
+ imgs = (
+ [image]
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
+ else [image[0], image[-1]]
+ )
+ if len(imgs) > 1:
+ print(f"Number of images: {len(image)}")
+ for i, img in enumerate(imgs):
+ if len(imgs) > 1:
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
+ else:
+ print(f"Image: {img.shape}")
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
+ ascii_img = from_pillow_image(pil_img)
+ ascii_img.to_terminal(columns=width, back=back)
+
+ if interactive:
+ k = input("Press q to quit, d to next dataset, other key to next: ")
+ if k == "q" or k == "d":
+ return ord(k)
+
+ if not interactive:
+ return ord(" ")
+ return ord(k) if k else ord(" ")
+
+
+def show_datasets(
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
+):
+ print(f"d: next dataset, q: quit")
+
+ num_workers = max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ print(f"Dataset [{i}]")
+ batch_index = 0
+ num_images_to_show = console_num_images
+ k = None
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
+ print(f"bucket resolution: {key}, count: {len(batch)}")
+ for j, item_info in enumerate(batch):
+ item_info: ItemInfo
+ print(f"{batch_index}-{j}: {item_info}")
+ if debug_mode == "image":
+ k = show_image(item_info.content)
+ elif debug_mode == "console":
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
+ if num_images_to_show is not None:
+ num_images_to_show -= 1
+ if num_images_to_show == 0:
+ k = ord("d") # next dataset
+
+ if k == ord("q"):
+ return
+ elif k == ord("d"):
+ break
+ if k == ord("d"):
+ break
+ batch_index += 1
+
+
+def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
+ if len(contents.shape) == 4:
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
+
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ contents = contents.to(vae.device, dtype=vae.dtype)
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
+
+ # print(f"encode batch: {contents.shape}")
+ with torch.no_grad():
+ latent = vae.encode(contents).latent_dist.sample()
+ latent = latent * vae.config.scaling_factor
+
+ # # debug: decode and save
+ # with torch.no_grad():
+ # latent_to_decode = latent / vae.config.scaling_factor
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
+ # images = (images / 2 + 0.5).clamp(0, 1)
+ # images = images.cpu().float().numpy()
+ # images = (images * 255).astype(np.uint8)
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
+ # for b in range(images.shape[0]):
+ # for f in range(images.shape[1]):
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
+ # img = Image.fromarray(images[b, f])
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
+
+ for item, l in zip(batch, latent):
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
+ save_latent_cache(item, l)
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ if args.debug_mode is not None:
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
+ return
+
+ assert args.vae is not None, "vae checkpoint is required"
+
+ # Load VAE model: HunyuanVideo VAE model is float16
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
+ vae.eval()
+ print(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
+
+ if args.vae_chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ elif args.vae_tiling:
+ vae.enable_spatial_tiling(True)
+
+ # Encode images
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ print(f"Encoding dataset [{i}]")
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
+ if args.skip_existing:
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
+ if len(filtered_batch) == 0:
+ continue
+ batch = filtered_batch
+
+ bs = args.batch_size if args.batch_size is not None else len(batch)
+ for i in range(0, len(batch), bs):
+ encode_and_save_batch(vae, batch[i : i + bs])
+
+
+def setup_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+ parser.add_argument(
+ "--vae_tiling",
+ action="store_true",
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
+ parser.add_argument(
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
+ )
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
+ parser.add_argument(
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
+ )
+ parser.add_argument(
+ "--console_num_images",
+ type=int,
+ default=None,
+ help="debug mode: not interactive, number of images to show for each dataset",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+
+ args = parser.parse_args()
+ main(args)
diff --git a/cache_text_encoder_outputs.py b/cache_text_encoder_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c68decaf235f1dd8c16b2ac8f1b9d5fef6c41c
--- /dev/null
+++ b/cache_text_encoder_outputs.py
@@ -0,0 +1,135 @@
+import argparse
+import os
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+import accelerate
+
+from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache
+from hunyuan_model import text_encoder as text_encoder_module
+from hunyuan_model.text_encoder import TextEncoder
+
+import logging
+
+from utils.model_utils import str_to_dtype
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
+ data_type = "video" # video only, image is not supported
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ with torch.no_grad():
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
+
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
+
+
+def encode_and_save_batch(
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
+):
+ prompts = [item.caption for item in batch]
+ # print(prompts)
+
+ # encode prompt
+ if accelerator is not None:
+ with accelerator.autocast():
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
+ else:
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
+
+ # # convert to fp16 if needed
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
+
+ # save prompt cache
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ # define accelerator for fp8 inference
+ accelerator = None
+ if args.fp8_llm:
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
+
+ # define encode function
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
+
+ def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):
+ for i, dataset in enumerate(datasets):
+ print(f"Encoding dataset [{i}]")
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
+ if args.skip_existing:
+ filtered_batch = [item for item in batch if not os.path.exists(item.text_encoder_output_cache_path)]
+ if len(filtered_batch) == 0:
+ continue
+ batch = filtered_batch
+
+ bs = args.batch_size if args.batch_size is not None else len(batch)
+ for i in range(0, len(batch), bs):
+ encode_and_save_batch(text_encoder, batch[i : i + bs], is_llm, accelerator)
+
+ # Load Text Encoder 1
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
+ text_encoder_1.to(device=device)
+
+ # Encode with Text Encoder 1
+ logger.info("Encoding with Text Encoder 1")
+ encode_for_text_encoder(text_encoder_1, is_llm=True)
+ del text_encoder_1
+
+ # Load Text Encoder 2
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
+ text_encoder_2.to(device=device)
+
+ # Encode with Text Encoder 2
+ logger.info("Encoding with Text Encoder 2")
+ encode_for_text_encoder(text_encoder_2, is_llm=False)
+ del text_encoder_2
+
+
+def setup_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ parser.add_argument(
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
+ )
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+
+ args = parser.parse_args()
+ main(args)
diff --git a/convert_lora.py b/convert_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..c49ad482969e1de224bdcff3fc67a535692371b8
--- /dev/null
+++ b/convert_lora.py
@@ -0,0 +1,129 @@
+import argparse
+
+import torch
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from utils import model_utils
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def convert_from_diffusers(prefix, weights_sd):
+ # convert from diffusers(?) to default LoRA
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
+ # note: Diffusers has no alpha, so alpha is set to rank
+ new_weights_sd = {}
+ lora_dims = {}
+ for key, weight in weights_sd.items():
+ diffusers_prefix, key_body = key.split(".", 1)
+ if diffusers_prefix != "diffusion_model":
+ logger.warning(f"unexpected key: {key} in diffusers format")
+ continue
+
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
+ new_weights_sd[new_key] = weight
+
+ lora_name = new_key.split(".")[0] # before first dot
+ if lora_name not in lora_dims and "lora_down" in new_key:
+ lora_dims[lora_name] = weight.shape[0]
+
+ # add alpha with rank
+ for lora_name, dim in lora_dims.items():
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
+
+ return new_weights_sd
+
+
+def convert_to_diffusers(prefix, weights_sd):
+ # convert from default LoRA to diffusers
+
+ # get alphas
+ lora_alphas = {}
+ for key, weight in weights_sd.items():
+ if key.startswith(prefix):
+ lora_name = key.split(".", 1)[0] # before first dot
+ if lora_name not in lora_alphas and "alpha" in key:
+ lora_alphas[lora_name] = weight
+
+ new_weights_sd = {}
+ for key, weight in weights_sd.items():
+ if key.startswith(prefix):
+ if "alpha" in key:
+ continue
+
+ lora_name = key.split(".", 1)[0] # before first dot
+
+ # HunyuanVideo lora name to module name: ugly but works
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
+ module_name = module_name.replace("_", ".") # replace "_" with "."
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
+ module_name = module_name.replace("img.", "img_") # fix img
+ module_name = module_name.replace("txt.", "txt_") # fix txt
+ module_name = module_name.replace("attn.", "attn_") # fix attn
+
+ diffusers_prefix = "diffusion_model"
+ if "lora_down" in key:
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
+ dim = weight.shape[0]
+ elif "lora_up" in key:
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
+ dim = weight.shape[1]
+ else:
+ logger.warning(f"unexpected key: {key} in default LoRA format")
+ continue
+
+ # scale weight by alpha
+ if lora_name in lora_alphas:
+ # we scale both down and up, so scale is sqrt
+ scale = lora_alphas[lora_name] / dim
+ scale = scale.sqrt()
+ weight = weight * scale
+ else:
+ logger.warning(f"missing alpha for {lora_name}")
+
+ new_weights_sd[new_key] = weight
+
+ return new_weights_sd
+
+
+def convert(input_file, output_file, target_format):
+ logger.info(f"loading {input_file}")
+ weights_sd = load_file(input_file)
+ with safe_open(input_file, framework="pt") as f:
+ metadata = f.metadata()
+
+ logger.info(f"converting to {target_format}")
+ prefix = "lora_unet_"
+ if target_format == "default":
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
+ metadata = metadata or {}
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
+ elif target_format == "other":
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
+ else:
+ raise ValueError(f"unknown target format: {target_format}")
+
+ logger.info(f"saving to {output_file}")
+ save_file(new_weights_sd, output_file, metadata=metadata)
+
+ logger.info("done")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
+ parser.add_argument("--input", type=str, required=True, help="input model file")
+ parser.add_argument("--output", type=str, required=True, help="output model file")
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ convert(args.input, args.output, args.target)
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dataset/config_utils.py b/dataset/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c00ffb10c359fbcabc58bc6b464f7e08e99ae3f
--- /dev/null
+++ b/dataset/config_utils.py
@@ -0,0 +1,359 @@
+import argparse
+from dataclasses import (
+ asdict,
+ dataclass,
+)
+import functools
+import random
+from textwrap import dedent, indent
+import json
+from pathlib import Path
+
+# from toolz import curry
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import toml
+import voluptuous
+from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
+
+from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class BaseDatasetParams:
+ resolution: Tuple[int, int] = (960, 544)
+ enable_bucket: bool = False
+ bucket_no_upscale: bool = False
+ caption_extension: Optional[str] = None
+ batch_size: int = 1
+ cache_directory: Optional[str] = None
+ debug_dataset: bool = False
+
+
+@dataclass
+class ImageDatasetParams(BaseDatasetParams):
+ image_directory: Optional[str] = None
+ image_jsonl_file: Optional[str] = None
+
+
+@dataclass
+class VideoDatasetParams(BaseDatasetParams):
+ video_directory: Optional[str] = None
+ video_jsonl_file: Optional[str] = None
+ target_frames: Sequence[int] = (1,)
+ frame_extraction: Optional[str] = "head"
+ frame_stride: Optional[int] = 1
+ frame_sample: Optional[int] = 1
+
+
+@dataclass
+class DatasetBlueprint:
+ is_image_dataset: bool
+ params: Union[ImageDatasetParams, VideoDatasetParams]
+
+
+@dataclass
+class DatasetGroupBlueprint:
+ datasets: Sequence[DatasetBlueprint]
+
+
+@dataclass
+class Blueprint:
+ dataset_group: DatasetGroupBlueprint
+
+
+class ConfigSanitizer:
+ # @curry
+ @staticmethod
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
+ Schema(ExactSequence([klass, klass]))(value)
+ return tuple(value)
+
+ # @curry
+ @staticmethod
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
+ try:
+ Schema(klass)(value)
+ return (value, value)
+ except:
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
+
+ # datasets schema
+ DATASET_ASCENDABLE_SCHEMA = {
+ "caption_extension": str,
+ "batch_size": int,
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
+ "enable_bucket": bool,
+ "bucket_no_upscale": bool,
+ }
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
+ "image_directory": str,
+ "image_jsonl_file": str,
+ "cache_directory": str,
+ }
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
+ "video_directory": str,
+ "video_jsonl_file": str,
+ "target_frames": [int],
+ "frame_extraction": str,
+ "frame_stride": int,
+ "frame_sample": int,
+ "cache_directory": str,
+ }
+
+ # options handled by argparse but not handled by user config
+ ARGPARSE_SPECIFIC_SCHEMA = {
+ "debug_dataset": bool,
+ }
+
+ def __init__(self) -> None:
+ self.image_dataset_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
+ )
+ self.video_dataset_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
+ )
+
+ def validate_flex_dataset(dataset_config: dict):
+ if "target_frames" in dataset_config:
+ return Schema(self.video_dataset_schema)(dataset_config)
+ else:
+ return Schema(self.image_dataset_schema)(dataset_config)
+
+ self.dataset_schema = validate_flex_dataset
+
+ self.general_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ )
+ self.user_config_validator = Schema(
+ {
+ "general": self.general_schema,
+ "datasets": [self.dataset_schema],
+ }
+ )
+ self.argparse_schema = self.__merge_dict(
+ self.ARGPARSE_SPECIFIC_SCHEMA,
+ )
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
+
+ def sanitize_user_config(self, user_config: dict) -> dict:
+ try:
+ return self.user_config_validator(user_config)
+ except MultipleInvalid:
+ # TODO: clarify the error message
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
+ raise
+
+ # NOTE: In nature, argument parser result is not needed to be sanitize
+ # However this will help us to detect program bug
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
+ try:
+ return self.argparse_config_validator(argparse_namespace)
+ except MultipleInvalid:
+ # XXX: this should be a bug
+ logger.error(
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+ )
+ raise
+
+ # NOTE: value would be overwritten by latter dict if there is already the same key
+ @staticmethod
+ def __merge_dict(*dict_list: dict) -> dict:
+ merged = {}
+ for schema in dict_list:
+ # merged |= schema
+ for k, v in schema.items():
+ merged[k] = v
+ return merged
+
+
+class BlueprintGenerator:
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
+
+ def __init__(self, sanitizer: ConfigSanitizer):
+ self.sanitizer = sanitizer
+
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
+
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
+ general_config = sanitized_user_config.get("general", {})
+
+ dataset_blueprints = []
+ for dataset_config in sanitized_user_config.get("datasets", []):
+ is_image_dataset = "target_frames" not in dataset_config
+ if is_image_dataset:
+ dataset_params_klass = ImageDatasetParams
+ else:
+ dataset_params_klass = VideoDatasetParams
+
+ params = self.generate_params_by_fallbacks(
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
+ )
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
+
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
+
+ return Blueprint(dataset_group_blueprint)
+
+ @staticmethod
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
+ search_value = BlueprintGenerator.search_value
+ default_params = asdict(param_klass())
+ param_names = default_params.keys()
+
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
+
+ return param_klass(**params)
+
+ @staticmethod
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
+ for cand in fallbacks:
+ value = cand.get(key)
+ if value is not None:
+ return value
+
+ return default_value
+
+
+# if training is True, it will return a dataset group for training, otherwise for caching
+def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
+
+ for dataset_blueprint in dataset_group_blueprint.datasets:
+ if dataset_blueprint.is_image_dataset:
+ dataset_klass = ImageDataset
+ else:
+ dataset_klass = VideoDataset
+
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
+ datasets.append(dataset)
+
+ # print info
+ info = ""
+ for i, dataset in enumerate(datasets):
+ is_image_dataset = isinstance(dataset, ImageDataset)
+ info += dedent(
+ f"""\
+ [Dataset {i}]
+ is_image_dataset: {is_image_dataset}
+ resolution: {dataset.resolution}
+ batch_size: {dataset.batch_size}
+ caption_extension: "{dataset.caption_extension}"
+ enable_bucket: {dataset.enable_bucket}
+ bucket_no_upscale: {dataset.bucket_no_upscale}
+ cache_directory: "{dataset.cache_directory}"
+ debug_dataset: {dataset.debug_dataset}
+ """
+ )
+
+ if is_image_dataset:
+ info += indent(
+ dedent(
+ f"""\
+ image_directory: "{dataset.image_directory}"
+ image_jsonl_file: "{dataset.image_jsonl_file}"
+ \n"""
+ ),
+ " ",
+ )
+ else:
+ info += indent(
+ dedent(
+ f"""\
+ video_directory: "{dataset.video_directory}"
+ video_jsonl_file: "{dataset.video_jsonl_file}"
+ target_frames: {dataset.target_frames}
+ frame_extraction: {dataset.frame_extraction}
+ frame_stride: {dataset.frame_stride}
+ frame_sample: {dataset.frame_sample}
+ \n"""
+ ),
+ " ",
+ )
+ logger.info(f"{info}")
+
+ # make buckets first because it determines the length of dataset
+ # and set the same seed for all datasets
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
+ for i, dataset in enumerate(datasets):
+ # logger.info(f"[Dataset {i}]")
+ dataset.set_seed(seed)
+ if training:
+ dataset.prepare_for_training()
+
+ return DatasetGroup(datasets)
+
+
+def load_user_config(file: str) -> dict:
+ file: Path = Path(file)
+ if not file.is_file():
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
+
+ if file.name.lower().endswith(".json"):
+ try:
+ with open(file, "r") as f:
+ config = json.load(f)
+ except Exception:
+ logger.error(
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+ )
+ raise
+ elif file.name.lower().endswith(".toml"):
+ try:
+ config = toml.load(file)
+ except Exception:
+ logger.error(
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+ )
+ raise
+ else:
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
+
+ return config
+
+
+# for config test
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("dataset_config")
+ config_args, remain = parser.parse_known_args()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--debug_dataset", action="store_true")
+ argparse_namespace = parser.parse_args(remain)
+
+ logger.info("[argparse_namespace]")
+ logger.info(f"{vars(argparse_namespace)}")
+
+ user_config = load_user_config(config_args.dataset_config)
+
+ logger.info("")
+ logger.info("[user_config]")
+ logger.info(f"{user_config}")
+
+ sanitizer = ConfigSanitizer()
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
+
+ logger.info("")
+ logger.info("[sanitized_user_config]")
+ logger.info(f"{sanitized_user_config}")
+
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
+
+ logger.info("")
+ logger.info("[blueprint]")
+ logger.info(f"{blueprint}")
+
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
diff --git a/dataset/dataset_config.md b/dataset/dataset_config.md
new file mode 100644
index 0000000000000000000000000000000000000000..e91bf2d5b5b58a45862417dfa3468d58fb1c899b
--- /dev/null
+++ b/dataset/dataset_config.md
@@ -0,0 +1,293 @@
+## Dataset Configuration
+
+Please create a TOML file for dataset configuration.
+
+Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
+
+### Sample for Image Dataset with Caption Text Files
+
+```toml
+# resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
+
+# general configurations
+[general]
+resolution = [960, 544]
+caption_extension = ".txt"
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+image_directory = "/path/to/image_dir"
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+### Sample for Image Dataset with Metadata JSONL File
+
+```toml
+# resolution, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
+# caption_extension is not required for metadata jsonl file
+# cache_directory is required for each dataset with metadata jsonl file
+
+# general configurations
+[general]
+resolution = [960, 544]
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+image_jsonl_file = "/path/to/metadata.jsonl"
+cache_directory = "/path/to/cache_directory"
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+JSONL file format for metadata:
+
+```json
+{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
+```
+
+### Sample for Video Dataset with Caption Text Files
+
+```toml
+# resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
+
+# general configurations
+[general]
+resolution = [960, 544]
+caption_extension = ".txt"
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+video_directory = "/path/to/video_dir"
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+### Sample for Video Dataset with Metadata JSONL File
+
+```toml
+# resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
+# caption_extension is not required for metadata jsonl file
+# cache_directory is required for each dataset with metadata jsonl file
+
+# general configurations
+[general]
+resolution = [960, 544]
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl"
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+cache_directory = "/path/to/cache_directory"
+
+# same metadata jsonl file can be used for multiple datasets
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl"
+target_frames = [1]
+frame_stride = 10
+cache_directory = "/path/to/cache_directory"
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+JSONL file format for metadata:
+
+```json
+{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
+{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
+```
+
+### fame_extraction Options
+
+- `head`: Extract the first N frames from the video.
+- `chunk`: Extract frames by splitting the video into chunks of N frames.
+- `slide`: Extract frames from the video with a stride of `frame_stride`.
+- `uniform`: Extract `frame_sample` samples uniformly from the video.
+
+For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
+
+```
+Original Video, 40 frames: x = frame, o = no frame
+oooooooooooooooooooooooooooooooooooooooo
+
+head, target_frames = [1, 13, 25] -> extract head frames:
+xooooooooooooooooooooooooooooooooooooooo
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+
+chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+oooooooooooooxxxxxxxxxxxxxoooooooooooooo
+ooooooooooooooooooooooooooxxxxxxxxxxxxxo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+
+NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+
+slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
+xooooooooooooooooooooooooooooooooooooooo
+ooooooooooxooooooooooooooooooooooooooooo
+ooooooooooooooooooooxooooooooooooooooooo
+ooooooooooooooooooooooooooooooxooooooooo
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+ooooooooooxxxxxxxxxxxxxooooooooooooooooo
+ooooooooooooooooooooxxxxxxxxxxxxxooooooo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
+
+uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
+xooooooooooooooooooooooooooooooooooooooo
+oooooooooooooxoooooooooooooooooooooooooo
+oooooooooooooooooooooooooxoooooooooooooo
+ooooooooooooooooooooooooooooooooooooooox
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+oooooooooxxxxxxxxxxxxxoooooooooooooooooo
+ooooooooooooooooooxxxxxxxxxxxxxooooooooo
+oooooooooooooooooooooooooooxxxxxxxxxxxxx
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
+ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
+oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
+```
+
+## Specifications
+
+```toml
+# general configurations
+[general]
+resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
+caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
+batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
+enable_bucket = true # optional, default is false. Enable bucketing for datasets
+bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
+
+### Image Dataset
+
+# sample image dataset with caption text files
+[[datasets]]
+image_directory = "/path/to/image_dir"
+caption_extension = ".txt" # required for caption text files, if general caption extension is not set
+resolution = [960, 544] # required if general resolution is not set
+batch_size = 4 # optional, overwrite the default batch size
+enable_bucket = false # optional, overwrite the default bucketing setting
+bucket_no_upscale = true # optional, overwrite the default bucketing setting
+cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
+
+# sample image dataset with metadata **jsonl** file
+[[datasets]]
+image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
+resolution = [960, 544] # required if general resolution is not set
+cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
+# caption_extension is not required for metadata jsonl file
+# batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
+
+### Video Dataset
+
+# sample video dataset with caption text files
+[[datasets]]
+video_directory = "/path/to/video_dir"
+caption_extension = ".txt" # required for caption text files, if general caption extension is not set
+resolution = [960, 544] # required if general resolution is not set
+
+target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
+
+# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+
+frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
+frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
+frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
+# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
+
+# sample video dataset with metadata jsonl file
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
+
+target_frames = [1, 79]
+
+cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
+# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
+```
+
+
+
+The metadata with .json file will be supported in the near future.
+
+
+
+
\ No newline at end of file
diff --git a/dataset/image_video_dataset.py b/dataset/image_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b6111ef8d3978ce4245fdd158cfe70aab9cb29
--- /dev/null
+++ b/dataset/image_video_dataset.py
@@ -0,0 +1,1255 @@
+from concurrent.futures import ThreadPoolExecutor
+import glob
+import json
+import math
+import os
+import random
+import time
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from safetensors.torch import save_file, load_file
+from safetensors import safe_open
+from PIL import Image
+import cv2
+import av
+
+from utils import safetensors_utils
+from utils.model_utils import dtype_to_str
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
+
+try:
+ import pillow_avif
+
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
+except:
+ pass
+
+# JPEG-XL on Linux
+try:
+ from jxlpy import JXLImagePlugin
+
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+ pass
+
+# JPEG-XL on Windows
+try:
+ import pillow_jxl
+
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+ pass
+
+VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"] # some of them are not tested
+
+ARCHITECTURE_HUNYUAN_VIDEO = "hv"
+
+
+def glob_images(directory, base="*"):
+ img_paths = []
+ for ext in IMAGE_EXTENSIONS:
+ if base == "*":
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+ else:
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+ img_paths = list(set(img_paths)) # remove duplicates
+ img_paths.sort()
+ return img_paths
+
+
+def glob_videos(directory, base="*"):
+ video_paths = []
+ for ext in VIDEO_EXTENSIONS:
+ if base == "*":
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+ else:
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+ video_paths = list(set(video_paths)) # remove duplicates
+ video_paths.sort()
+ return video_paths
+
+
+def divisible_by(num: int, divisor: int) -> int:
+ return num - num % divisor
+
+
+def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
+ """
+ Resize the image to the bucket resolution.
+ """
+ is_pil_image = isinstance(image, Image.Image)
+ if is_pil_image:
+ image_width, image_height = image.size
+ else:
+ image_height, image_width = image.shape[:2]
+
+ if bucket_reso == (image_width, image_height):
+ return np.array(image) if is_pil_image else image
+
+ bucket_width, bucket_height = bucket_reso
+ if bucket_width == image_width or bucket_height == image_height:
+ image = np.array(image) if is_pil_image else image
+ else:
+ # resize the image to the bucket resolution to match the short side
+ scale_width = bucket_width / image_width
+ scale_height = bucket_height / image_height
+ scale = max(scale_width, scale_height)
+ image_width = int(image_width * scale + 0.5)
+ image_height = int(image_height * scale + 0.5)
+
+ if scale > 1:
+ image = Image.fromarray(image) if not is_pil_image else image
+ image = image.resize((image_width, image_height), Image.LANCZOS)
+ image = np.array(image)
+ else:
+ image = np.array(image) if is_pil_image else image
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
+
+ # crop the image to the bucket resolution
+ crop_left = (image_width - bucket_width) // 2
+ crop_top = (image_height - bucket_height) // 2
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
+ return image
+
+
+class ItemInfo:
+ def __init__(
+ self,
+ item_key: str,
+ caption: str,
+ original_size: tuple[int, int],
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
+ frame_count: Optional[int] = None,
+ content: Optional[np.ndarray] = None,
+ latent_cache_path: Optional[str] = None,
+ ) -> None:
+ self.item_key = item_key
+ self.caption = caption
+ self.original_size = original_size
+ self.bucket_size = bucket_size
+ self.frame_count = frame_count
+ self.content = content
+ self.latent_cache_path = latent_cache_path
+ self.text_encoder_output_cache_path: Optional[str] = None
+
+ def __str__(self) -> str:
+ return (
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
+ )
+
+
+def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
+ assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
+ metadata = {
+ "architecture": "hunyuan_video",
+ "width": f"{item_info.original_size[0]}",
+ "height": f"{item_info.original_size[1]}",
+ "format_version": "1.0.0",
+ }
+ if item_info.frame_count is not None:
+ metadata["frame_count"] = f"{item_info.frame_count}"
+
+ _, F, H, W = latent.shape
+ dtype_str = dtype_to_str(latent.dtype)
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
+
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
+ os.makedirs(latent_dir, exist_ok=True)
+
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
+
+
+def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
+ assert (
+ embed.dim() == 1 or embed.dim() == 2
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
+ metadata = {
+ "architecture": "hunyuan_video",
+ "caption1": item_info.caption,
+ "format_version": "1.0.0",
+ }
+
+ sd = {}
+ if os.path.exists(item_info.text_encoder_output_cache_path):
+ # load existing cache and update metadata
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
+ existing_metadata = f.metadata()
+ for key in f.keys():
+ sd[key] = f.get_tensor(key)
+
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
+ if existing_metadata["caption1"] != metadata["caption1"]:
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
+ # TODO verify format_version
+
+ existing_metadata.pop("caption1", None)
+ existing_metadata.pop("format_version", None)
+ metadata.update(existing_metadata) # copy existing metadata
+ else:
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
+
+ dtype_str = dtype_to_str(embed.dtype)
+ text_encoder_type = "llm" if is_llm else "clipL"
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
+ if mask is not None:
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
+
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
+
+
+class BucketSelector:
+ RESOLUTION_STEPS_HUNYUAN = 16
+
+ def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
+ self.resolution = resolution
+ self.bucket_area = resolution[0] * resolution[1]
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
+
+ if not enable_bucket:
+ # only define one bucket
+ self.bucket_resolutions = [resolution]
+ self.no_upscale = False
+ else:
+ # prepare bucket resolution
+ self.no_upscale = no_upscale
+ sqrt_size = int(math.sqrt(self.bucket_area))
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
+ self.bucket_resolutions = []
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
+ self.bucket_resolutions.append((w, h))
+ self.bucket_resolutions.append((h, w))
+
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
+ self.bucket_resolutions.sort()
+
+ # calculate aspect ratio to find the nearest resolution
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
+
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
+ """
+ return the bucket resolution for the given image size, (width, height)
+ """
+ area = image_size[0] * image_size[1]
+ if self.no_upscale and area <= self.bucket_area:
+ w, h = image_size
+ w = divisible_by(w, self.reso_steps)
+ h = divisible_by(h, self.reso_steps)
+ return w, h
+
+ aspect_ratio = image_size[0] / image_size[1]
+ ar_errors = self.aspect_ratios - aspect_ratio
+ bucket_id = np.abs(ar_errors).argmin()
+ return self.bucket_resolutions[bucket_id]
+
+
+def load_video(
+ video_path: str,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+) -> list[np.ndarray]:
+ container = av.open(video_path)
+ video = []
+ bucket_reso = None
+ for i, frame in enumerate(container.decode(video=0)):
+ if start_frame is not None and i < start_frame:
+ continue
+ if end_frame is not None and i >= end_frame:
+ break
+ frame = frame.to_image()
+
+ if bucket_selector is not None and bucket_reso is None:
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
+
+ if bucket_reso is not None:
+ frame = resize_image_to_bucket(frame, bucket_reso)
+ else:
+ frame = np.array(frame)
+
+ video.append(frame)
+ container.close()
+ return video
+
+
+class BucketBatchManager:
+
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
+ self.batch_size = batch_size
+ self.buckets = bucketed_item_info
+ self.bucket_resos = list(self.buckets.keys())
+ self.bucket_resos.sort()
+
+ self.bucket_batch_indices = []
+ for bucket_reso in self.bucket_resos:
+ bucket = self.buckets[bucket_reso]
+ num_batches = math.ceil(len(bucket) / self.batch_size)
+ for i in range(num_batches):
+ self.bucket_batch_indices.append((bucket_reso, i))
+
+ self.shuffle()
+
+ def show_bucket_info(self):
+ for bucket_reso in self.bucket_resos:
+ bucket = self.buckets[bucket_reso]
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
+
+ logger.info(f"total batches: {len(self)}")
+
+ def shuffle(self):
+ for bucket in self.buckets.values():
+ random.shuffle(bucket)
+ random.shuffle(self.bucket_batch_indices)
+
+ def __len__(self):
+ return len(self.bucket_batch_indices)
+
+ def __getitem__(self, idx):
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
+ bucket = self.buckets[bucket_reso]
+ start = batch_idx * self.batch_size
+ end = min(start + self.batch_size, len(bucket))
+
+ latents = []
+ llm_embeds = []
+ llm_masks = []
+ clip_l_embeds = []
+ for item_info in bucket[start:end]:
+ sd = load_file(item_info.latent_cache_path)
+ latent = None
+ for key in sd.keys():
+ if key.startswith("latents_"):
+ latent = sd[key]
+ break
+ latents.append(latent)
+
+ sd = load_file(item_info.text_encoder_output_cache_path)
+ llm_embed = llm_mask = clip_l_embed = None
+ for key in sd.keys():
+ if key.startswith("llm_mask"):
+ llm_mask = sd[key]
+ elif key.startswith("llm_"):
+ llm_embed = sd[key]
+ elif key.startswith("clipL_mask"):
+ pass
+ elif key.startswith("clipL_"):
+ clip_l_embed = sd[key]
+ llm_embeds.append(llm_embed)
+ llm_masks.append(llm_mask)
+ clip_l_embeds.append(clip_l_embed)
+
+ latents = torch.stack(latents)
+ llm_embeds = torch.stack(llm_embeds)
+ llm_masks = torch.stack(llm_masks)
+ clip_l_embeds = torch.stack(clip_l_embeds)
+
+ return latents, llm_embeds, llm_masks, clip_l_embeds
+
+
+class ContentDatasource:
+ def __init__(self):
+ self.caption_only = False
+
+ def set_caption_only(self, caption_only: bool):
+ self.caption_only = caption_only
+
+ def is_indexable(self):
+ return False
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ """
+ Returns caption. May not be called if is_indexable() returns False.
+ """
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __next__(self):
+ raise NotImplementedError
+
+
+class ImageDatasource(ContentDatasource):
+ def __init__(self):
+ super().__init__()
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
+ """
+ Returns image data as a tuple of image path, image, and caption for the given index.
+ Key must be unique and valid as a file name.
+ May not be called if is_indexable() returns False.
+ """
+ raise NotImplementedError
+
+
+class ImageDirectoryDatasource(ImageDatasource):
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
+ super().__init__()
+ self.image_directory = image_directory
+ self.caption_extension = caption_extension
+ self.current_idx = 0
+
+ # glob images
+ logger.info(f"glob images in {self.image_directory}")
+ self.image_paths = glob_images(self.image_directory)
+ logger.info(f"found {len(self.image_paths)} images")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
+ image_path = self.image_paths[idx]
+ image = Image.open(image_path).convert("RGB")
+
+ _, caption = self.get_caption(idx)
+
+ return image_path, image, caption
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ image_path = self.image_paths[idx]
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
+ with open(caption_path, "r", encoding="utf-8") as f:
+ caption = f.read().strip()
+ return image_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self) -> callable:
+ """
+ Returns a fetcher function that returns image data.
+ """
+ if self.current_idx >= len(self.image_paths):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+ else:
+
+ def create_image_fetcher(index):
+ return lambda: self.get_image_data(index)
+
+ fetcher = create_image_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class ImageJsonlDatasource(ImageDatasource):
+ def __init__(self, image_jsonl_file: str):
+ super().__init__()
+ self.image_jsonl_file = image_jsonl_file
+ self.current_idx = 0
+
+ # load jsonl
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
+ self.data = []
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ self.data.append(data)
+ logger.info(f"loaded {len(self.data)} images")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.data)
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
+ data = self.data[idx]
+ image_path = data["image_path"]
+ image = Image.open(image_path).convert("RGB")
+
+ caption = data["caption"]
+
+ return image_path, image, caption
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ data = self.data[idx]
+ image_path = data["image_path"]
+ caption = data["caption"]
+ return image_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self) -> callable:
+ if self.current_idx >= len(self.data):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_image_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class VideoDatasource(ContentDatasource):
+ def __init__(self):
+ super().__init__()
+
+ # None means all frames
+ self.start_frame = None
+ self.end_frame = None
+
+ self.bucket_selector = None
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def get_video_data_from_path(
+ self,
+ video_path: str,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str]:
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
+
+ start_frame = start_frame if start_frame is not None else self.start_frame
+ end_frame = end_frame if end_frame is not None else self.end_frame
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
+
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
+ return video
+
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
+ self.start_frame = start_frame
+ self.end_frame = end_frame
+
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
+ self.bucket_selector = bucket_selector
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __next__(self):
+ raise NotImplementedError
+
+
+class VideoDirectoryDatasource(VideoDatasource):
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
+ super().__init__()
+ self.video_directory = video_directory
+ self.caption_extension = caption_extension
+ self.current_idx = 0
+
+ # glob images
+ logger.info(f"glob images in {self.video_directory}")
+ self.video_paths = glob_videos(self.video_directory)
+ logger.info(f"found {len(self.video_paths)} videos")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.video_paths)
+
+ def get_video_data(
+ self,
+ idx: int,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str]:
+ video_path = self.video_paths[idx]
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
+
+ _, caption = self.get_caption(idx)
+
+ return video_path, video, caption
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ video_path = self.video_paths[idx]
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
+ with open(caption_path, "r", encoding="utf-8") as f:
+ caption = f.read().strip()
+ return video_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= len(self.video_paths):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_video_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class VideoJsonlDatasource(VideoDatasource):
+ def __init__(self, video_jsonl_file: str):
+ super().__init__()
+ self.video_jsonl_file = video_jsonl_file
+ self.current_idx = 0
+
+ # load jsonl
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
+ self.data = []
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ self.data.append(data)
+ logger.info(f"loaded {len(self.data)} videos")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.data)
+
+ def get_video_data(
+ self,
+ idx: int,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str]:
+ data = self.data[idx]
+ video_path = data["video_path"]
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
+
+ caption = data["caption"]
+
+ return video_path, video, caption
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ data = self.data[idx]
+ video_path = data["video_path"]
+ caption = data["caption"]
+ return video_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= len(self.data):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_video_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class BaseDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ resolution: Tuple[int, int] = (960, 544),
+ caption_extension: Optional[str] = None,
+ batch_size: int = 1,
+ enable_bucket: bool = False,
+ bucket_no_upscale: bool = False,
+ cache_directory: Optional[str] = None,
+ debug_dataset: bool = False,
+ ):
+ self.resolution = resolution
+ self.caption_extension = caption_extension
+ self.batch_size = batch_size
+ self.enable_bucket = enable_bucket
+ self.bucket_no_upscale = bucket_no_upscale
+ self.cache_directory = cache_directory
+ self.debug_dataset = debug_dataset
+ self.seed = None
+ self.current_epoch = 0
+
+ if not self.enable_bucket:
+ self.bucket_no_upscale = False
+
+ def get_metadata(self) -> dict:
+ metadata = {
+ "resolution": self.resolution,
+ "caption_extension": self.caption_extension,
+ "batch_size_per_device": self.batch_size,
+ "enable_bucket": bool(self.enable_bucket),
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
+ }
+ return metadata
+
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
+ w, h = item_info.original_size
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")
+
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
+ return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ raise NotImplementedError
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ raise NotImplementedError
+
+ def prepare_for_training(self):
+ pass
+
+ def set_seed(self, seed: int):
+ self.seed = seed
+
+ def set_current_epoch(self, epoch):
+ if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
+ if epoch > self.current_epoch:
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+ num_epochs = epoch - self.current_epoch
+ for _ in range(num_epochs):
+ self.current_epoch += 1
+ self.shuffle_buckets()
+ # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
+ else:
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+ self.current_epoch = epoch
+
+ def set_current_step(self, step):
+ self.current_step = step
+
+ def set_max_train_steps(self, max_train_steps):
+ self.max_train_steps = max_train_steps
+
+ def shuffle_buckets(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ return NotImplementedError
+
+ def __getitem__(self, idx):
+ raise NotImplementedError
+
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
+ datasource.set_caption_only(True)
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ data: list[ItemInfo] = []
+ futures = []
+
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ item_key, caption = future.result()
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
+ data.append(item_info)
+
+ futures.remove(future)
+
+ def submit_batch(flush: bool = False):
+ nonlocal data
+ if len(data) >= batch_size or (len(data) > 0 and flush):
+ batch = data[0:batch_size]
+ if len(data) > batch_size:
+ data = data[batch_size:]
+ else:
+ data = []
+ return batch
+ return None
+
+ for fetch_op in datasource:
+ future = executor.submit(fetch_op)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ batch = submit_batch()
+ if batch is None:
+ break
+ yield batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ batch = submit_batch(flush=True)
+ if batch is None:
+ break
+ yield batch
+
+ executor.shutdown()
+
+
+class ImageDataset(BaseDataset):
+ def __init__(
+ self,
+ resolution: Tuple[int, int],
+ caption_extension: Optional[str],
+ batch_size: int,
+ enable_bucket: bool,
+ bucket_no_upscale: bool,
+ image_directory: Optional[str] = None,
+ image_jsonl_file: Optional[str] = None,
+ cache_directory: Optional[str] = None,
+ debug_dataset: bool = False,
+ ):
+ super(ImageDataset, self).__init__(
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
+ )
+ self.image_directory = image_directory
+ self.image_jsonl_file = image_jsonl_file
+ if image_directory is not None:
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
+ elif image_jsonl_file is not None:
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
+ else:
+ raise ValueError("image_directory or image_jsonl_file must be specified")
+
+ if self.cache_directory is None:
+ self.cache_directory = self.image_directory
+
+ self.batch_manager = None
+ self.num_train_items = 0
+
+ def get_metadata(self):
+ metadata = super().get_metadata()
+ if self.image_directory is not None:
+ metadata["image_directory"] = os.path.basename(self.image_directory)
+ if self.image_jsonl_file is not None:
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
+ return metadata
+
+ def get_total_image_count(self):
+ return len(self.datasource) if self.datasource.is_indexable() else None
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
+ futures = []
+
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ original_size, item_key, image, caption = future.result()
+ bucket_height, bucket_width = image.shape[:2]
+ bucket_reso = (bucket_width, bucket_height)
+
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
+
+ if bucket_reso not in batches:
+ batches[bucket_reso] = []
+ batches[bucket_reso].append(item_info)
+
+ futures.remove(future)
+
+ def submit_batch(flush: bool = False):
+ for key in batches:
+ if len(batches[key]) >= self.batch_size or flush:
+ batch = batches[key][0 : self.batch_size]
+ if len(batches[key]) > self.batch_size:
+ batches[key] = batches[key][self.batch_size :]
+ else:
+ del batches[key]
+ return key, batch
+ return None, None
+
+ for fetch_op in self.datasource:
+
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
+ image_key, image, caption = op()
+ image: Image.Image
+ image_size = image.size
+
+ bucket_reso = buckset_selector.get_bucket_resolution(image_size)
+ image = resize_image_to_bucket(image, bucket_reso)
+ return image_size, image_key, image, caption
+
+ future = executor.submit(fetch_and_resize, fetch_op)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ key, batch = submit_batch()
+ if key is None:
+ break
+ yield key, batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ key, batch = submit_batch(flush=True)
+ if key is None:
+ break
+ yield key, batch
+
+ executor.shutdown()
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
+
+ def prepare_for_training(self):
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
+
+ # glob cache files
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
+
+ # assign cache files to item info
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
+ for cache_file in latent_cache_files:
+ tokens = os.path.basename(cache_file).split("_")
+
+ image_size = tokens[-2] # 0000x0000
+ image_width, image_height = map(int, image_size.split("x"))
+ image_size = (image_width, image_height)
+
+ item_key = "_".join(tokens[:-2])
+ text_encoder_output_cache_file = os.path.join(
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
+ )
+ if not os.path.exists(text_encoder_output_cache_file):
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
+ continue
+
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
+
+ bucket = bucketed_item_info.get(bucket_reso, [])
+ bucket.append(item_info)
+ bucketed_item_info[bucket_reso] = bucket
+
+ # prepare batch manager
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
+ self.batch_manager.show_bucket_info()
+
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
+
+ def shuffle_buckets(self):
+ # set random seed for this epoch
+ random.seed(self.seed + self.current_epoch)
+ self.batch_manager.shuffle()
+
+ def __len__(self):
+ if self.batch_manager is None:
+ return 100 # dummy value
+ return len(self.batch_manager)
+
+ def __getitem__(self, idx):
+ return self.batch_manager[idx]
+
+
+class VideoDataset(BaseDataset):
+ def __init__(
+ self,
+ resolution: Tuple[int, int],
+ caption_extension: Optional[str],
+ batch_size: int,
+ enable_bucket: bool,
+ bucket_no_upscale: bool,
+ frame_extraction: Optional[str] = "head",
+ frame_stride: Optional[int] = 1,
+ frame_sample: Optional[int] = 1,
+ target_frames: Optional[list[int]] = None,
+ video_directory: Optional[str] = None,
+ video_jsonl_file: Optional[str] = None,
+ cache_directory: Optional[str] = None,
+ debug_dataset: bool = False,
+ ):
+ super(VideoDataset, self).__init__(
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
+ )
+ self.video_directory = video_directory
+ self.video_jsonl_file = video_jsonl_file
+ self.target_frames = target_frames
+ self.frame_extraction = frame_extraction
+ self.frame_stride = frame_stride
+ self.frame_sample = frame_sample
+
+ if video_directory is not None:
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
+ elif video_jsonl_file is not None:
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
+
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
+ self.frame_extraction = "head"
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
+ if self.frame_extraction == "head":
+ # head extraction. we can limit the number of frames to be extracted
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
+
+ if self.cache_directory is None:
+ self.cache_directory = self.video_directory
+
+ self.batch_manager = None
+ self.num_train_items = 0
+
+ def get_metadata(self):
+ metadata = super().get_metadata()
+ if self.video_directory is not None:
+ metadata["video_directory"] = os.path.basename(self.video_directory)
+ if self.video_jsonl_file is not None:
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
+ metadata["frame_extraction"] = self.frame_extraction
+ metadata["frame_stride"] = self.frame_stride
+ metadata["frame_sample"] = self.frame_sample
+ metadata["target_frames"] = self.target_frames
+ return metadata
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ buckset_selector = BucketSelector(self.resolution)
+ self.datasource.set_bucket_selector(buckset_selector)
+
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ # key: (width, height, frame_count), value: [ItemInfo]
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
+ futures = []
+
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ original_frame_size, video_key, video, caption = future.result()
+
+ frame_count = len(video)
+ video = np.stack(video, axis=0)
+ height, width = video.shape[1:3]
+ bucket_reso = (width, height) # already resized
+
+ crop_pos_and_frames = []
+ if self.frame_extraction == "head":
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ crop_pos_and_frames.append((0, target_frame))
+ elif self.frame_extraction == "chunk":
+ # split by target_frames
+ for target_frame in self.target_frames:
+ for i in range(0, frame_count, target_frame):
+ if i + target_frame <= frame_count:
+ crop_pos_and_frames.append((i, target_frame))
+ elif self.frame_extraction == "slide":
+ # slide window
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
+ crop_pos_and_frames.append((i, target_frame))
+ elif self.frame_extraction == "uniform":
+ # select N frames uniformly
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
+ for i in frame_indices:
+ crop_pos_and_frames.append((i, target_frame))
+ else:
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
+
+ for crop_pos, target_frame in crop_pos_and_frames:
+ cropped_video = video[crop_pos : crop_pos + target_frame]
+ body, ext = os.path.splitext(video_key)
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
+
+ item_info = ItemInfo(
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
+ )
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
+
+ batch = batches.get(batch_key, [])
+ batch.append(item_info)
+ batches[batch_key] = batch
+
+ futures.remove(future)
+
+ def submit_batch(flush: bool = False):
+ for key in batches:
+ if len(batches[key]) >= self.batch_size or flush:
+ batch = batches[key][0 : self.batch_size]
+ if len(batches[key]) > self.batch_size:
+ batches[key] = batches[key][self.batch_size :]
+ else:
+ del batches[key]
+ return key, batch
+ return None, None
+
+ for operator in self.datasource:
+
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
+ video_key, video, caption = op()
+ video: list[np.ndarray]
+ frame_size = (video[0].shape[1], video[0].shape[0])
+
+ # resize if necessary
+ bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
+
+ return frame_size, video_key, video, caption
+
+ future = executor.submit(fetch_and_resize, operator)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ key, batch = submit_batch()
+ if key is None:
+ break
+ yield key, batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ key, batch = submit_batch(flush=True)
+ if key is None:
+ break
+ yield key, batch
+
+ executor.shutdown()
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
+
+ def prepare_for_training(self):
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
+
+ # glob cache files
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
+
+ # assign cache files to item info
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
+ for cache_file in latent_cache_files:
+ tokens = os.path.basename(cache_file).split("_")
+
+ image_size = tokens[-2] # 0000x0000
+ image_width, image_height = map(int, image_size.split("x"))
+ image_size = (image_width, image_height)
+
+ frame_pos, frame_count = tokens[-3].split("-")
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
+
+ item_key = "_".join(tokens[:-3])
+ text_encoder_output_cache_file = os.path.join(
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
+ )
+ if not os.path.exists(text_encoder_output_cache_file):
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
+ continue
+
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
+ bucket_reso = (*bucket_reso, frame_count)
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
+
+ bucket = bucketed_item_info.get(bucket_reso, [])
+ bucket.append(item_info)
+ bucketed_item_info[bucket_reso] = bucket
+
+ # prepare batch manager
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
+ self.batch_manager.show_bucket_info()
+
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
+
+ def shuffle_buckets(self):
+ # set random seed for this epoch
+ random.seed(self.seed + self.current_epoch)
+ self.batch_manager.shuffle()
+
+ def __len__(self):
+ if self.batch_manager is None:
+ return 100 # dummy value
+ return len(self.batch_manager)
+
+ def __getitem__(self, idx):
+ return self.batch_manager[idx]
+
+
+class DatasetGroup(torch.utils.data.ConcatDataset):
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
+ super().__init__(datasets)
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
+ self.num_train_items = 0
+ for dataset in self.datasets:
+ self.num_train_items += dataset.num_train_items
+
+ def set_current_epoch(self, epoch):
+ for dataset in self.datasets:
+ dataset.set_current_epoch(epoch)
+
+ def set_current_step(self, step):
+ for dataset in self.datasets:
+ dataset.set_current_step(step)
+
+ def set_max_train_steps(self, max_train_steps):
+ for dataset in self.datasets:
+ dataset.set_max_train_steps(max_train_steps)
diff --git a/hunyuan_model/__init__.py b/hunyuan_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuan_model/activation_layers.py b/hunyuan_model/activation_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8774c26ceef6081482ca0dbbf930b207d4ac03b
--- /dev/null
+++ b/hunyuan_model/activation_layers.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+
+def get_activation_layer(act_type):
+ """get activation layer
+
+ Args:
+ act_type (str): the activation type
+
+ Returns:
+ torch.nn.functional: the activation layer
+ """
+ if act_type == "gelu":
+ return lambda: nn.GELU()
+ elif act_type == "gelu_tanh":
+ # Approximate `tanh` requires torch >= 1.13
+ return lambda: nn.GELU(approximate="tanh")
+ elif act_type == "relu":
+ return nn.ReLU
+ elif act_type == "silu":
+ return nn.SiLU
+ else:
+ raise ValueError(f"Unknown activation type: {act_type}")
diff --git a/hunyuan_model/attention.py b/hunyuan_model/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb8d812445ef94b1bb94c0a779a4b1f3c8ec2349
--- /dev/null
+++ b/hunyuan_model/attention.py
@@ -0,0 +1,230 @@
+import importlib.metadata
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ import flash_attn
+ from flash_attn.flash_attn_interface import _flash_attn_forward
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+ flash_attn = None
+ flash_attn_varlen_func = None
+ _flash_attn_forward = None
+
+try:
+ print(f"Trying to import sageattention")
+ from sageattention import sageattn_varlen
+
+ print("Successfully imported sageattention")
+except ImportError:
+ print(f"Failed to import flash_attn and sageattention")
+ sageattn_varlen = None
+
+MEMORY_LAYOUT = {
+ "flash": (
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+ lambda x: x,
+ ),
+ "sageattn": (
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+ lambda x: x,
+ ),
+ "torch": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+ "vanilla": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+}
+
+
+def get_cu_seqlens(text_mask, img_len):
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
+
+ Args:
+ text_mask (torch.Tensor): the mask of text
+ img_len (int): the length of image
+
+ Returns:
+ torch.Tensor: the calculated cu_seqlens for flash attention
+ """
+ batch_size = text_mask.shape[0]
+ text_len = text_mask.sum(dim=1)
+ max_len = text_mask.shape[1] + img_len
+
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
+
+ for i in range(batch_size):
+ s = text_len[i] + img_len
+ s1 = i * max_len + s
+ s2 = (i + 1) * max_len
+ cu_seqlens[2 * i + 1] = s1
+ cu_seqlens[2 * i + 2] = s2
+
+ return cu_seqlens
+
+
+def attention(
+ q_or_qkv_list,
+ k=None,
+ v=None,
+ mode="flash",
+ drop_rate=0,
+ attn_mask=None,
+ causal=False,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ batch_size=1,
+):
+ """
+ Perform QKV self attention.
+
+ Args:
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
+ drop_rate (float): Dropout rate in attention map. (default: 0)
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+ (default: None)
+ causal (bool): Whether to use causal attention. (default: False)
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into q.
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into kv.
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
+
+ Returns:
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+ """
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+ q = pre_attn_layout(q)
+ k = pre_attn_layout(k)
+ v = pre_attn_layout(v)
+
+ if mode == "torch":
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+ if type(q_or_qkv_list) == list:
+ q_or_qkv_list.clear()
+ del q, k, v
+ del attn_mask
+ elif mode == "flash":
+ x = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ )
+ if type(q_or_qkv_list) == list:
+ q_or_qkv_list.clear()
+ del q, k, v
+ # x with shape [(bxs), a, d]
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
+ elif mode == "sageattn":
+ x = sageattn_varlen(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ )
+ if type(q_or_qkv_list) == list:
+ q_or_qkv_list.clear()
+ del q, k, v
+ # x with shape [(bxs), a, d]
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
+ elif mode == "vanilla":
+ scale_factor = 1 / math.sqrt(q.size(-1))
+
+ b, a, s, _ = q.shape
+ s1 = k.size(2)
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+ if causal:
+ # Only applied to self attention
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(q.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
+ attn += attn_bias
+ attn = attn.softmax(dim=-1)
+ attn = torch.dropout(attn, p=drop_rate, train=True)
+ x = attn @ v
+ else:
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+ x = post_attn_layout(x)
+ b, s, a, d = x.shape
+ out = x.reshape(b, s, -1)
+ return out
+
+
+def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
+ attn1 = hybrid_seq_parallel_attn(
+ None,
+ q[:, :img_q_len, :, :],
+ k[:, :img_kv_len, :, :],
+ v[:, :img_kv_len, :, :],
+ dropout_p=0.0,
+ causal=False,
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
+ joint_strategy="rear",
+ )
+ if flash_attn.__version__ >= "2.7.0":
+ attn2, *_ = _flash_attn_forward(
+ q[:, cu_seqlens_q[1] :],
+ k[:, cu_seqlens_kv[1] :],
+ v[:, cu_seqlens_kv[1] :],
+ dropout_p=0.0,
+ softmax_scale=q.shape[-1] ** (-0.5),
+ causal=False,
+ window_size_left=-1,
+ window_size_right=-1,
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
+ )
+ else:
+ attn2, *_ = _flash_attn_forward(
+ q[:, cu_seqlens_q[1] :],
+ k[:, cu_seqlens_kv[1] :],
+ v[:, cu_seqlens_kv[1] :],
+ dropout_p=0.0,
+ softmax_scale=q.shape[-1] ** (-0.5),
+ causal=False,
+ window_size=(-1, -1),
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
+ )
+ attn = torch.cat([attn1, attn2], dim=1)
+ b, s, a, d = attn.shape
+ attn = attn.reshape(b, s, -1)
+
+ return attn
diff --git a/hunyuan_model/autoencoder_kl_causal_3d.py b/hunyuan_model/autoencoder_kl_causal_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30f250f8f59167f291abd5e705061641020e52b
--- /dev/null
+++ b/hunyuan_model/autoencoder_kl_causal_3d.py
@@ -0,0 +1,609 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+try:
+ # This diffusers is modified and packed in the mirror.
+ from diffusers.loaders import FromOriginalVAEMixin
+except ImportError:
+ # Use this to be compatible with the original diffusers.
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
+
+
+@dataclass
+class DecoderOutput2(BaseOutput):
+ sample: torch.FloatTensor
+ posterior: Optional[DiagonalGaussianDistribution] = None
+
+
+class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ sample_tsize: int = 64,
+ scaling_factor: float = 0.18215,
+ force_upcast: float = True,
+ spatial_compression_ratio: int = 8,
+ time_compression_ratio: int = 4,
+ mid_block_add_attention: bool = True,
+ ):
+ super().__init__()
+
+ self.time_compression_ratio = time_compression_ratio
+
+ self.encoder = EncoderCausal3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.decoder = DecoderCausal3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+ self.use_slicing = False
+ self.use_spatial_tiling = False
+ self.use_temporal_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_tsize = sample_tsize
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
+
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
+ module.gradient_checkpointing = value
+
+ def enable_temporal_tiling(self, use_tiling: bool = True):
+ self.use_temporal_tiling = use_tiling
+
+ def disable_temporal_tiling(self):
+ self.enable_temporal_tiling(False)
+
+ def enable_spatial_tiling(self, use_tiling: bool = True):
+ self.use_spatial_tiling = use_tiling
+
+ def disable_spatial_tiling(self):
+ self.enable_spatial_tiling(False)
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger videos.
+ """
+ self.enable_spatial_tiling(use_tiling)
+ self.enable_temporal_tiling(use_tiling)
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.disable_spatial_tiling()
+ self.disable_temporal_tiling()
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
+ # set chunk_size to CausalConv3d recursively
+ def set_chunk_size(module):
+ if hasattr(module, "chunk_size"):
+ module.chunk_size = chunk_size
+
+ self.apply(set_chunk_size)
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images/videos into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images/videos.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
+
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
+
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
+
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
+
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images/videos.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
+ return b
+
+ def spatial_tiled_encode(
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
+ ) -> AutoencoderKLOutput:
+ r"""Encode a batch of images/videos using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images/videos.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split video into tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[-2], overlap_size):
+ row = []
+ for j in range(0, x.shape[-1], overlap_size):
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ moments = torch.cat(result_rows, dim=-2)
+ if return_moments:
+ return moments
+
+ posterior = DiagonalGaussianDistribution(moments)
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Decode a batch of images/videos using a tiled decoder.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[-2], overlap_size):
+ row = []
+ for j in range(0, z.shape[-1], overlap_size):
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=-2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+
+ B, C, T, H, W = x.shape
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
+ t_limit = self.tile_latent_min_tsize - blend_extent
+
+ # Split the video into tiles and encode them separately.
+ row = []
+ for i in range(0, T, overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
+ if self.use_spatial_tiling and (
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
+ ):
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
+ else:
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ if i > 0:
+ tile = tile[:, :, 1:, :, :]
+ row.append(tile)
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :t_limit, :, :])
+ else:
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
+
+ moments = torch.cat(result_row, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ # Split z into overlapping tiles and decode them separately.
+
+ B, C, T, H, W = z.shape
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
+ t_limit = self.tile_sample_min_tsize - blend_extent
+
+ row = []
+ for i in range(0, T, overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
+ if self.use_spatial_tiling and (
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
+ ):
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
+ else:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ if i > 0:
+ decoded = decoded[:, :, 1:, :, :]
+ row.append(decoded)
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :t_limit, :, :])
+ else:
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
+
+ dec = torch.cat(result_row, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ return_posterior: bool = False,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ if return_posterior:
+ return (dec, posterior)
+ else:
+ return (dec,)
+ if return_posterior:
+ return DecoderOutput2(sample=dec, posterior=posterior)
+ else:
+ return DecoderOutput2(sample=dec)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
diff --git a/hunyuan_model/embed_layers.py b/hunyuan_model/embed_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31ba9cc58d1aa05e0f17b919762f69bd693b5c0
--- /dev/null
+++ b/hunyuan_model/embed_layers.py
@@ -0,0 +1,132 @@
+import collections
+import math
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from .helpers import to_2tuple
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding
+
+ Image to Patch Embedding using Conv2d
+
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
+
+ Based on the impl in https://github.com/google-research/vision_transformer
+
+ Hacked together by / Copyright 2020 Ross Wightman
+
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
+ """
+
+ def __init__(
+ self,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ bias=True,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+ self.flatten = flatten
+
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
+ if bias:
+ nn.init.zeros_(self.proj.bias)
+
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
+
+class TextProjection(nn.Module):
+ """
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
+ self.act_1 = act_layer()
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ Args:
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ dim (int): the dimension of the output.
+ max_period (int): controls the minimum frequency of the embeddings.
+
+ Returns:
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
+
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ """
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ act_layer,
+ frequency_embedding_size=256,
+ max_period=10000,
+ out_size=None,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.frequency_embedding_size = frequency_embedding_size
+ self.max_period = max_period
+ if out_size is None:
+ out_size = hidden_size
+
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+ act_layer(),
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+ )
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
diff --git a/hunyuan_model/helpers.py b/hunyuan_model/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ab8cb1feba4ce7782f1ea841fd42c71be7b0d1
--- /dev/null
+++ b/hunyuan_model/helpers.py
@@ -0,0 +1,40 @@
+import collections.abc
+
+from itertools import repeat
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ x = tuple(x)
+ if len(x) == 1:
+ x = tuple(repeat(x[0], n))
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+
+
+def as_tuple(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ if x is None or isinstance(x, (int, float, str)):
+ return (x,)
+ else:
+ raise ValueError(f"Unknown type {type(x)}")
+
+
+def as_list_of_2tuple(x):
+ x = as_tuple(x)
+ if len(x) == 1:
+ x = (x[0], x[0])
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
+ lst = []
+ for i in range(0, len(x), 2):
+ lst.append((x[i], x[i + 1]))
+ return lst
diff --git a/hunyuan_model/mlp_layers.py b/hunyuan_model/mlp_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc9547a6a0ba80ab19a472a9ea7aef525f46613
--- /dev/null
+++ b/hunyuan_model/mlp_layers.py
@@ -0,0 +1,118 @@
+# Modified from timm library:
+# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from .modulate_layers import modulate
+from .helpers import to_2tuple
+
+
+class MLP(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ device=None,
+ dtype=None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ out_features = out_features or in_channels
+ hidden_channels = hidden_channels or in_channels
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
+ )
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = (
+ norm_layer(hidden_channels, **factory_kwargs)
+ if norm_layer is not None
+ else nn.Identity()
+ )
+ self.fc2 = linear_layer(
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
+ )
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+#
+class MLPEmbedder(nn.Module):
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class FinalLayer(nn.Module):
+ """The final layer of DiT."""
+
+ def __init__(
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ # Just use LayerNorm for the final layer
+ self.norm_final = nn.LayerNorm(
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
+ )
+ if isinstance(patch_size, int):
+ self.linear = nn.Linear(
+ hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True,
+ **factory_kwargs
+ )
+ else:
+ self.linear = nn.Linear(
+ hidden_size,
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
+ bias=True,
+ )
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ # Here we don't distinguish between the modulate types. Just use the simple one.
+ self.adaLN_modulation = nn.Sequential(
+ act_layer(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
+ x = self.linear(x)
+ return x
diff --git a/hunyuan_model/models.py b/hunyuan_model/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..604cfa46d91a173d94e3979506332cdda3542d44
--- /dev/null
+++ b/hunyuan_model/models.py
@@ -0,0 +1,997 @@
+import os
+from typing import Any, List, Tuple, Optional, Union, Dict
+import accelerate
+from einops import rearrange
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .activation_layers import get_activation_layer
+from .norm_layers import get_norm_layer
+from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
+from .attention import attention, parallel_attention, get_cu_seqlens
+from .posemb_layers import apply_rotary_emb
+from .mlp_layers import MLP, MLPEmbedder, FinalLayer
+from .modulate_layers import ModulateDiT, modulate, apply_gate
+from .token_refiner import SingleTokenRefiner
+from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
+from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
+
+from utils.safetensors_utils import MemoryEfficientSafeOpen
+
+
+class MMDoubleStreamBlock(nn.Module):
+ """
+ A multimodal dit block with seperate modulation for
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
+ (Flux.1): https://github.com/black-forest-labs/flux
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ mlp_width_ratio: float,
+ mlp_act_type: str = "gelu_tanh",
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qkv_bias: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+
+ self.deterministic = False
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+ self.img_mod = ModulateDiT(
+ hidden_size,
+ factor=6,
+ act_layer=get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.img_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.img_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.img_mlp = MLP(
+ hidden_size,
+ mlp_hidden_dim,
+ act_layer=get_activation_layer(mlp_act_type),
+ bias=True,
+ **factory_kwargs,
+ )
+
+ self.txt_mod = ModulateDiT(
+ hidden_size,
+ factor=6,
+ act_layer=get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ self.txt_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.txt_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.txt_mlp = MLP(
+ hidden_size,
+ mlp_hidden_dim,
+ act_layer=get_activation_layer(mlp_act_type),
+ bias=True,
+ **factory_kwargs,
+ )
+ self.hybrid_seq_parallel_attn = None
+
+ self.gradient_checkpointing = False
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def _forward(
+ self,
+ img: torch.Tensor,
+ txt: torch.Tensor,
+ vec: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_kv: Optional[int] = None,
+ freqs_cis: tuple = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
+ 6, dim=-1
+ )
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
+ 6, dim=-1
+ )
+
+ # Prepare image for attention.
+ img_modulated = self.img_norm1(img)
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
+ img_qkv = self.img_attn_qkv(img_modulated)
+ img_modulated = None
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ img_qkv = None
+ # Apply QK-Norm if needed
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
+
+ # Apply RoPE if needed.
+ if freqs_cis is not None:
+ img_q_shape = img_q.shape
+ img_k_shape = img_k.shape
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
+ assert (
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
+ ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
+ # img_q, img_k = img_qq, img_kk
+
+ # Prepare txt for attention.
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
+ txt_modulated = None
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ txt_qkv = None
+ # Apply QK-Norm if needed.
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
+
+ # Run actual attention.
+ img_q_len = img_q.shape[1]
+ img_kv_len = img_k.shape[1]
+ batch_size = img_k.shape[0]
+ q = torch.cat((img_q, txt_q), dim=1)
+ img_q = txt_q = None
+ k = torch.cat((img_k, txt_k), dim=1)
+ img_k = txt_k = None
+ v = torch.cat((img_v, txt_v), dim=1)
+ img_v = txt_v = None
+
+ assert (
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
+
+ # attention computation start
+ if not self.hybrid_seq_parallel_attn:
+ l = [q, k, v]
+ q = k = v = None
+ attn = attention(
+ l,
+ mode=self.attn_mode,
+ attn_mask=attn_mask,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ batch_size=batch_size,
+ )
+ else:
+ attn = parallel_attention(
+ self.hybrid_seq_parallel_attn,
+ q,
+ k,
+ v,
+ img_q_len=img_q_len,
+ img_kv_len=img_kv_len,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ )
+
+ # attention computation end
+
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
+ attn = None
+
+ # Calculate the img bloks.
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
+ img_attn = None
+ img = img + apply_gate(
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
+ gate=img_mod2_gate,
+ )
+
+ # Calculate the txt bloks.
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
+ txt_attn = None
+ txt = txt + apply_gate(
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
+ gate=txt_mod2_gate,
+ )
+
+ return img, txt
+
+ # def forward(
+ # self,
+ # img: torch.Tensor,
+ # txt: torch.Tensor,
+ # vec: torch.Tensor,
+ # attn_mask: Optional[torch.Tensor] = None,
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
+ # max_seqlen_q: Optional[int] = None,
+ # max_seqlen_kv: Optional[int] = None,
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+class MMSingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
+ (Flux.1): https://github.com/black-forest-labs/flux
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_act_type: str = "gelu_tanh",
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qk_scale: float = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+
+ self.deterministic = False
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+ self.mlp_hidden_dim = mlp_hidden_dim
+ self.scale = qk_scale or head_dim**-0.5
+
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
+
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.mlp_act = get_activation_layer(mlp_act_type)()
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
+ self.hybrid_seq_parallel_attn = None
+
+ self.gradient_checkpointing = False
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ vec: torch.Tensor,
+ txt_len: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_kv: Optional[int] = None,
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ ) -> torch.Tensor:
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+ x_mod = None
+ # mlp = mlp.to("cpu", non_blocking=True)
+ # clean_memory_on_device(x.device)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ qkv = None
+
+ # Apply QK-Norm if needed.
+ q = self.q_norm(q).to(v)
+ k = self.k_norm(k).to(v)
+
+ # Apply RoPE if needed.
+ if freqs_cis is not None:
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
+ q = k = None
+ img_q_shape = img_q.shape
+ img_k_shape = img_k.shape
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
+ assert (
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
+ ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
+ # img_q, img_k = img_qq, img_kk
+ # del img_qq, img_kk
+ q = torch.cat((img_q, txt_q), dim=1)
+ k = torch.cat((img_k, txt_k), dim=1)
+ del img_q, txt_q, img_k, txt_k
+
+ # Compute attention.
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
+
+ # attention computation start
+ if not self.hybrid_seq_parallel_attn:
+ l = [q, k, v]
+ q = k = v = None
+ attn = attention(
+ l,
+ mode=self.attn_mode,
+ attn_mask=attn_mask,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ batch_size=x.shape[0],
+ )
+ else:
+ attn = parallel_attention(
+ self.hybrid_seq_parallel_attn,
+ q,
+ k,
+ v,
+ img_q_len=img_q.shape[1],
+ img_kv_len=img_k.shape[1],
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ )
+ # attention computation end
+
+ # Compute activation in mlp stream, cat again and run second linear layer.
+ # mlp = mlp.to(x.device)
+ mlp = self.mlp_act(mlp)
+ attn_mlp = torch.cat((attn, mlp), 2)
+ attn = None
+ mlp = None
+ output = self.linear2(attn_mlp)
+ attn_mlp = None
+ return x + apply_gate(output, gate=mod_gate)
+
+ # def forward(
+ # self,
+ # x: torch.Tensor,
+ # vec: torch.Tensor,
+ # txt_len: int,
+ # attn_mask: Optional[torch.Tensor] = None,
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
+ # max_seqlen_q: Optional[int] = None,
+ # max_seqlen_kv: Optional[int] = None,
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ # ) -> torch.Tensor:
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
+ """
+ HunyuanVideo Transformer backbone
+
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
+
+ Reference:
+ [1] Flux.1: https://github.com/black-forest-labs/flux
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
+
+ Parameters
+ ----------
+ args: argparse.Namespace
+ The arguments parsed by argparse.
+ patch_size: list
+ The size of the patch.
+ in_channels: int
+ The number of input channels.
+ out_channels: int
+ The number of output channels.
+ hidden_size: int
+ The hidden size of the transformer backbone.
+ heads_num: int
+ The number of attention heads.
+ mlp_width_ratio: float
+ The ratio of the hidden size of the MLP in the transformer block.
+ mlp_act_type: str
+ The activation function of the MLP in the transformer block.
+ depth_double_blocks: int
+ The number of transformer blocks in the double blocks.
+ depth_single_blocks: int
+ The number of transformer blocks in the single blocks.
+ rope_dim_list: list
+ The dimension of the rotary embedding for t, h, w.
+ qkv_bias: bool
+ Whether to use bias in the qkv linear layer.
+ qk_norm: bool
+ Whether to use qk norm.
+ qk_norm_type: str
+ The type of qk norm.
+ guidance_embed: bool
+ Whether to use guidance embedding for distillation.
+ text_projection: str
+ The type of the text projection, default is single_refiner.
+ use_attention_mask: bool
+ Whether to use attention mask for text encoder.
+ dtype: torch.dtype
+ The dtype of the model.
+ device: torch.device
+ The device of the model.
+ attn_mode: str
+ The mode of the attention, default is flash.
+ """
+
+ # @register_to_config
+ def __init__(
+ self,
+ text_states_dim: int,
+ text_states_dim_2: int,
+ patch_size: list = [1, 2, 2],
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
+ out_channels: int = None,
+ hidden_size: int = 3072,
+ heads_num: int = 24,
+ mlp_width_ratio: float = 4.0,
+ mlp_act_type: str = "gelu_tanh",
+ mm_double_blocks_depth: int = 20,
+ mm_single_blocks_depth: int = 40,
+ rope_dim_list: List[int] = [16, 56, 56],
+ qkv_bias: bool = True,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ guidance_embed: bool = False, # For modulation.
+ text_projection: str = "single_refiner",
+ use_attention_mask: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.unpatchify_channels = self.out_channels
+ self.guidance_embed = guidance_embed
+ self.rope_dim_list = rope_dim_list
+
+ # Text projection. Default to linear projection.
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
+ self.use_attention_mask = use_attention_mask
+ self.text_projection = text_projection
+
+ self.text_states_dim = text_states_dim
+ self.text_states_dim_2 = text_states_dim_2
+
+ if hidden_size % heads_num != 0:
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
+ pe_dim = hidden_size // heads_num
+ if sum(rope_dim_list) != pe_dim:
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+
+ self.attn_mode = attn_mode
+
+ # image projection
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
+
+ # text projection
+ if self.text_projection == "linear":
+ self.txt_in = TextProjection(
+ self.text_states_dim,
+ self.hidden_size,
+ get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ elif self.text_projection == "single_refiner":
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
+ else:
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
+
+ # time modulation
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
+
+ # text modulation
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
+
+ # guidance modulation
+ self.guidance_in = (
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
+ )
+
+ # double blocks
+ self.double_blocks = nn.ModuleList(
+ [
+ MMDoubleStreamBlock(
+ self.hidden_size,
+ self.heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_act_type=mlp_act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ attn_mode=attn_mode,
+ **factory_kwargs,
+ )
+ for _ in range(mm_double_blocks_depth)
+ ]
+ )
+
+ # single blocks
+ self.single_blocks = nn.ModuleList(
+ [
+ MMSingleStreamBlock(
+ self.hidden_size,
+ self.heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_act_type=mlp_act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ attn_mode=attn_mode,
+ **factory_kwargs,
+ )
+ for _ in range(mm_single_blocks_depth)
+ ]
+ )
+
+ self.final_layer = FinalLayer(
+ self.hidden_size,
+ self.patch_size,
+ self.out_channels,
+ get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+
+ self.gradient_checkpointing = False
+ self.blocks_to_swap = None
+ self.offloader_double = None
+ self.offloader_single = None
+ self._enable_img_in_txt_in_offloading = False
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ self.txt_in.enable_gradient_checkpointing()
+
+ for block in self.double_blocks + self.single_blocks:
+ block.enable_gradient_checkpointing()
+
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
+
+ def enable_img_in_txt_in_offloading(self):
+ self._enable_img_in_txt_in_offloading = True
+
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
+ self.blocks_to_swap = num_blocks
+ self.num_double_blocks = len(self.double_blocks)
+ self.num_single_blocks = len(self.single_blocks)
+ double_blocks_to_swap = num_blocks // 2
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
+
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
+ )
+
+ self.offloader_double = ModelOffloader(
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
+ )
+ self.offloader_single = ModelOffloader(
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
+ )
+ print(
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
+ )
+
+ def move_to_device_except_swap_blocks(self, device: torch.device):
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
+ if self.blocks_to_swap:
+ save_double_blocks = self.double_blocks
+ save_single_blocks = self.single_blocks
+ self.double_blocks = None
+ self.single_blocks = None
+
+ self.to(device)
+
+ if self.blocks_to_swap:
+ self.double_blocks = save_double_blocks
+ self.single_blocks = save_single_blocks
+
+ def prepare_block_swap_before_forward(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+
+ def enable_deterministic(self):
+ for block in self.double_blocks:
+ block.enable_deterministic()
+ for block in self.single_blocks:
+ block.enable_deterministic()
+
+ def disable_deterministic(self):
+ for block in self.double_blocks:
+ block.disable_deterministic()
+ for block in self.single_blocks:
+ block.disable_deterministic()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor, # Should be in range(0, 1000).
+ text_states: torch.Tensor = None,
+ text_mask: torch.Tensor = None, # Now we don't use it.
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
+ freqs_cos: Optional[torch.Tensor] = None,
+ freqs_sin: Optional[torch.Tensor] = None,
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ out = {}
+ img = x
+ txt = text_states
+ _, _, ot, oh, ow = x.shape
+ tt, th, tw = (
+ ot // self.patch_size[0],
+ oh // self.patch_size[1],
+ ow // self.patch_size[2],
+ )
+
+ # Prepare modulation vectors.
+ vec = self.time_in(t)
+
+ # text modulation
+ vec = vec + self.vector_in(text_states_2)
+
+ # guidance modulation
+ if self.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
+ vec = vec + self.guidance_in(guidance)
+
+ # Embed image and text.
+ if self._enable_img_in_txt_in_offloading:
+ self.img_in.to(x.device, non_blocking=True)
+ self.txt_in.to(x.device, non_blocking=True)
+ synchronize_device(x.device)
+
+ img = self.img_in(img)
+ if self.text_projection == "linear":
+ txt = self.txt_in(txt)
+ elif self.text_projection == "single_refiner":
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
+ else:
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
+
+ if self._enable_img_in_txt_in_offloading:
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
+ synchronize_device(x.device)
+ clean_memory_on_device(x.device)
+
+ txt_seq_len = txt.shape[1]
+ img_seq_len = img.shape[1]
+
+ # Compute cu_squlens and max_seqlen for flash attention
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
+ cu_seqlens_kv = cu_seqlens_q
+ max_seqlen_q = img_seq_len + txt_seq_len
+ max_seqlen_kv = max_seqlen_q
+
+ attn_mask = None
+ if self.attn_mode == "torch":
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
+ bs = img.shape[0]
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
+
+ # calculate text length and total length
+ text_len = text_mask.sum(dim=1) # (bs, )
+ total_len = img_seq_len + text_len # (bs, )
+
+ # set attention mask
+ for i in range(bs):
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
+
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
+ # --------------------- Pass through DiT blocks ------------------------
+ for block_idx, block in enumerate(self.double_blocks):
+ double_block_args = [
+ img,
+ txt,
+ vec,
+ attn_mask,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ freqs_cis,
+ ]
+
+ if self.blocks_to_swap:
+ self.offloader_double.wait_for_block(block_idx)
+
+ img, txt = block(*double_block_args)
+
+ if self.blocks_to_swap:
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
+
+ # Merge txt and img to pass through single stream blocks.
+ x = torch.cat((img, txt), 1)
+ if self.blocks_to_swap:
+ # delete img, txt to reduce memory usage
+ del img, txt
+ clean_memory_on_device(x.device)
+
+ if len(self.single_blocks) > 0:
+ for block_idx, block in enumerate(self.single_blocks):
+ single_block_args = [
+ x,
+ vec,
+ txt_seq_len,
+ attn_mask,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ (freqs_cos, freqs_sin),
+ ]
+ if self.blocks_to_swap:
+ self.offloader_single.wait_for_block(block_idx)
+
+ x = block(*single_block_args)
+
+ if self.blocks_to_swap:
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
+
+ img = x[:, :img_seq_len, ...]
+ x = None
+
+ # ---------------------------- Final layer ------------------------------
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+
+ img = self.unpatchify(img, tt, th, tw)
+ if return_dict:
+ out["x"] = img
+ return out
+ return img
+
+ def unpatchify(self, x, t, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.unpatchify_channels
+ pt, ph, pw = self.patch_size
+ assert t * h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
+ x = torch.einsum("nthwcopq->nctohpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
+
+ return imgs
+
+ def params_count(self):
+ counts = {
+ "double": sum(
+ [
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
+ + sum(p.numel() for p in block.img_mlp.parameters())
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
+ + sum(p.numel() for p in block.txt_mlp.parameters())
+ for block in self.double_blocks
+ ]
+ ),
+ "single": sum(
+ [
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
+ for block in self.single_blocks
+ ]
+ ),
+ "total": sum(p.numel() for p in self.parameters()),
+ }
+ counts["attn+mlp"] = counts["double"] + counts["single"]
+ return counts
+
+
+#################################################################################
+# HunyuanVideo Configs #
+#################################################################################
+
+HUNYUAN_VIDEO_CONFIG = {
+ "HYVideo-T/2": {
+ "mm_double_blocks_depth": 20,
+ "mm_single_blocks_depth": 40,
+ "rope_dim_list": [16, 56, 56],
+ "hidden_size": 3072,
+ "heads_num": 24,
+ "mlp_width_ratio": 4,
+ },
+ "HYVideo-T/2-cfgdistill": {
+ "mm_double_blocks_depth": 20,
+ "mm_single_blocks_depth": 40,
+ "rope_dim_list": [16, 56, 56],
+ "hidden_size": 3072,
+ "heads_num": 24,
+ "mlp_width_ratio": 4,
+ "guidance_embed": True,
+ },
+}
+
+
+def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
+ """load hunyuan video model
+
+ NOTE: Only support HYVideo-T/2-cfgdistill now.
+
+ Args:
+ text_state_dim (int): text state dimension
+ text_state_dim_2 (int): text state dimension 2
+ in_channels (int): input channels number
+ out_channels (int): output channels number
+ factor_kwargs (dict): factor kwargs
+
+ Returns:
+ model (nn.Module): The hunyuan video model
+ """
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
+ model = HYVideoDiffusionTransformer(
+ text_states_dim=text_states_dim,
+ text_states_dim_2=text_states_dim_2,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
+ **factor_kwargs,
+ )
+ return model
+ # else:
+ # raise NotImplementedError()
+
+
+def load_state_dict(model, model_path):
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+
+ load_key = "module"
+ if load_key in state_dict:
+ state_dict = state_dict[load_key]
+ else:
+ raise KeyError(
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
+ f"are: {list(state_dict.keys())}."
+ )
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ return model
+
+
+def load_transformer(dit_path, attn_mode, device, dtype) -> HYVideoDiffusionTransformer:
+ # =========================== Build main model ===========================
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode}
+ latent_channels = 16
+ in_channels = latent_channels
+ out_channels = latent_channels
+
+ with accelerate.init_empty_weights():
+ transformer = load_dit_model(
+ text_states_dim=4096,
+ text_states_dim_2=768,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ factor_kwargs=factor_kwargs,
+ )
+
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
+ # loading safetensors: may be already fp8
+ with MemoryEfficientSafeOpen(dit_path) as f:
+ state_dict = {}
+ for k in f.keys():
+ tensor = f.get_tensor(k)
+ tensor = tensor.to(device=device, dtype=dtype)
+ # TODO support comfy model
+ # if k.startswith("model.model."):
+ # k = convert_comfy_model_key(k)
+ state_dict[k] = tensor
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
+ else:
+ transformer = load_state_dict(transformer, dit_path)
+
+ return transformer
+
+
+def get_rotary_pos_embed_by_shape(model, latents_size):
+ target_ndim = 3
+ ndim = 5 - 2
+
+ if isinstance(model.patch_size, int):
+ assert all(s % model.patch_size == 0 for s in latents_size), (
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
+ f"but got {latents_size}."
+ )
+ rope_sizes = [s // model.patch_size for s in latents_size]
+ elif isinstance(model.patch_size, list):
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
+ f"but got {latents_size}."
+ )
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
+
+ if len(rope_sizes) != target_ndim:
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
+ head_dim = model.hidden_size // model.heads_num
+ rope_dim_list = model.rope_dim_list
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+
+ rope_theta = 256
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
+ )
+ return freqs_cos, freqs_sin
+
+
+def get_rotary_pos_embed(vae_name, model, video_length, height, width):
+ # 884
+ if "884" in vae_name:
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
+ elif "888" in vae_name:
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
+ else:
+ latents_size = [video_length, height // 8, width // 8]
+
+ return get_rotary_pos_embed_by_shape(model, latents_size)
diff --git a/hunyuan_model/modulate_layers.py b/hunyuan_model/modulate_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a57c6d2fdc0fca9bf44aeee6996bf1d8a05901
--- /dev/null
+++ b/hunyuan_model/modulate_layers.py
@@ -0,0 +1,76 @@
+from typing import Callable
+
+import torch
+import torch.nn as nn
+
+
+class ModulateDiT(nn.Module):
+ """Modulation layer for DiT."""
+ def __init__(
+ self,
+ hidden_size: int,
+ factor: int,
+ act_layer: Callable,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.act = act_layer()
+ self.linear = nn.Linear(
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear(self.act(x))
+
+
+def modulate(x, shift=None, scale=None):
+ """modulate by shift and scale
+
+ Args:
+ x (torch.Tensor): input tensor.
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
+
+ Returns:
+ torch.Tensor: the output tensor after modulate.
+ """
+ if scale is None and shift is None:
+ return x
+ elif shift is None:
+ return x * (1 + scale.unsqueeze(1))
+ elif scale is None:
+ return x + shift.unsqueeze(1)
+ else:
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def apply_gate(x, gate=None, tanh=False):
+ """AI is creating summary for apply_gate
+
+ Args:
+ x (torch.Tensor): input tensor.
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
+
+ Returns:
+ torch.Tensor: the output tensor after apply gate.
+ """
+ if gate is None:
+ return x
+ if tanh:
+ return x * gate.unsqueeze(1).tanh()
+ else:
+ return x * gate.unsqueeze(1)
+
+
+def ckpt_wrapper(module):
+ def ckpt_forward(*inputs):
+ outputs = module(*inputs)
+ return outputs
+
+ return ckpt_forward
diff --git a/hunyuan_model/norm_layers.py b/hunyuan_model/norm_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53d167436b6971d3aabf5cfe51c0b9d6dfc022f
--- /dev/null
+++ b/hunyuan_model/norm_layers.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ elementwise_affine=True,
+ eps: float = 1e-6,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ # output = output * self.weight
+ # support fp8
+ output = output * self.weight.to(output.dtype)
+ return output
+
+
+def get_norm_layer(norm_layer):
+ """
+ Get the normalization layer.
+
+ Args:
+ norm_layer (str): The type of normalization layer.
+
+ Returns:
+ norm_layer (nn.Module): The normalization layer.
+ """
+ if norm_layer == "layer":
+ return nn.LayerNorm
+ elif norm_layer == "rms":
+ return RMSNorm
+ else:
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
diff --git a/hunyuan_model/pipeline_hunyuan_video.py b/hunyuan_model/pipeline_hunyuan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1293161e13a47ae7dcedfef2c55e3baefc655f4
--- /dev/null
+++ b/hunyuan_model/pipeline_hunyuan_video.py
@@ -0,0 +1,1100 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+import torch
+import torch.distributed as dist
+import numpy as np
+from dataclasses import dataclass
+from packaging import version
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput
+
+from ...constants import PRECISION_TO_TYPE
+from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from ...text_encoder import TextEncoder
+from ...modules import HYVideoDiffusionTransformer
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
+ )
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = (
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ )
+ return noise_cfg
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+@dataclass
+class HunyuanVideoPipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+
+class HunyuanVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`TextEncoder`]):
+ Frozen text-encoder.
+ text_encoder_2 ([`TextEncoder`]):
+ Frozen text-encoder_2.
+ transformer ([`HYVideoDiffusionTransformer`]):
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = ["text_encoder_2"]
+ _exclude_from_cpu_offload = ["transformer"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: TextEncoder,
+ transformer: HYVideoDiffusionTransformer,
+ scheduler: KarrasDiffusionSchedulers,
+ text_encoder_2: Optional[TextEncoder] = None,
+ progress_bar_config: Dict[str, Any] = None,
+ args=None,
+ ):
+ super().__init__()
+
+ # ==========================================================================================
+ if progress_bar_config is None:
+ progress_bar_config = {}
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ self._progress_bar_config.update(progress_bar_config)
+
+ self.args = args
+ # ==========================================================================================
+
+ if (
+ hasattr(scheduler.config, "steps_offset")
+ and scheduler.config.steps_offset != 1
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate(
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if (
+ hasattr(scheduler.config, "clip_sample")
+ and scheduler.config.clip_sample is True
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_attention_mask: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ text_encoder: Optional[TextEncoder] = None,
+ data_type: Optional[str] = "image",
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ attention_mask (`torch.Tensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_attention_mask (`torch.Tensor`, *optional*):
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ text_encoder (TextEncoder, *optional*):
+ data_type (`str`, *optional*):
+ """
+ if text_encoder is None:
+ text_encoder = self.text_encoder
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
+ else:
+ scale_lora_layers(text_encoder.model, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
+
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ if clip_skip is None:
+ prompt_outputs = text_encoder.encode(
+ text_inputs, data_type=data_type, device=device
+ )
+ prompt_embeds = prompt_outputs.hidden_state
+ else:
+ prompt_outputs = text_encoder.encode(
+ text_inputs,
+ output_hidden_states=True,
+ data_type=data_type,
+ device=device,
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
+ prompt_embeds
+ )
+
+ attention_mask = prompt_outputs.attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(device)
+ bs_embed, seq_len = attention_mask.shape
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
+ attention_mask = attention_mask.view(
+ bs_embed * num_videos_per_prompt, seq_len
+ )
+
+ if text_encoder is not None:
+ prompt_embeds_dtype = text_encoder.dtype
+ elif self.transformer is not None:
+ prompt_embeds_dtype = self.transformer.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ if prompt_embeds.ndim == 2:
+ bs_embed, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
+ else:
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(
+ bs_embed * num_videos_per_prompt, seq_len, -1
+ )
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(
+ uncond_tokens, text_encoder.tokenizer
+ )
+
+ # max_length = prompt_embeds.shape[1]
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
+
+ negative_prompt_outputs = text_encoder.encode(
+ uncond_input, data_type=data_type, device=device
+ )
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
+
+ negative_attention_mask = negative_prompt_outputs.attention_mask
+ if negative_attention_mask is not None:
+ negative_attention_mask = negative_attention_mask.to(device)
+ _, seq_len = negative_attention_mask.shape
+ negative_attention_mask = negative_attention_mask.repeat(
+ 1, num_videos_per_prompt
+ )
+ negative_attention_mask = negative_attention_mask.view(
+ batch_size * num_videos_per_prompt, seq_len
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=prompt_embeds_dtype, device=device
+ )
+
+ if negative_prompt_embeds.ndim == 2:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_videos_per_prompt
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_videos_per_prompt, -1
+ )
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_videos_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_videos_per_prompt, seq_len, -1
+ )
+
+ if text_encoder is not None:
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(text_encoder.model, lora_scale)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ def decode_latents(self, latents, enable_tiling=True):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(latents, return_dict=False)[0]
+ else:
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ if image.ndim == 4:
+ image = image.cpu().permute(0, 2, 3, 1).float()
+ else:
+ image = image.cpu().float()
+ return image
+
+ def prepare_extra_func_kwargs(self, func, kwargs):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ extra_step_kwargs = {}
+
+ for k, v in kwargs.items():
+ accepts = k in set(inspect.signature(func).parameters.keys())
+ if accepts:
+ extra_step_kwargs[k] = v
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ video_length,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ vae_ver="88-4c-sd",
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
+
+ if video_length is not None:
+ if "884" in vae_ver:
+ if video_length != 1 and (video_length - 1) % 4 != 0:
+ raise ValueError(
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
+ )
+ elif "888" in vae_ver:
+ if video_length != 1 and (video_length - 1) % 8 != 0:
+ raise ValueError(
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
+ )
+
+ if callback_steps is not None and (
+ not isinstance(callback_steps, int) or callback_steps <= 0
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs
+ for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (
+ not isinstance(prompt, str) and not isinstance(prompt, list)
+ ):
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ video_length,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(
+ shape, generator=generator, device=device, dtype=dtype
+ )
+ else:
+ latents = latents.to(device)
+
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self,
+ w: torch.Tensor,
+ embedding_dim: int = 512,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ video_length: int,
+ data_type: str = "video",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[
+ Callable[[int, int, Dict], None],
+ PipelineCallback,
+ MultiPipelineCallbacks,
+ ]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ vae_ver: str = "88-4c-sd",
+ enable_tiling: bool = False,
+ n_tokens: Optional[int] = None,
+ embedded_guidance_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ video_length (`int`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. Default height and width to unet
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ video_length,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ vae_ver=vae_ver,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None)
+ if self.cross_attention_kwargs is not None
+ else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_mask,
+ negative_prompt_mask,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ attention_mask=attention_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_attention_mask=negative_attention_mask,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ data_type=data_type,
+ )
+ if self.text_encoder_2 is not None:
+ (
+ prompt_embeds_2,
+ negative_prompt_embeds_2,
+ prompt_mask_2,
+ negative_prompt_mask_2,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=None,
+ attention_mask=None,
+ negative_prompt_embeds=None,
+ negative_attention_mask=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ text_encoder=self.text_encoder_2,
+ data_type=data_type,
+ )
+ else:
+ prompt_embeds_2 = None
+ negative_prompt_embeds_2 = None
+ prompt_mask_2 = None
+ negative_prompt_mask_2 = None
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if prompt_mask is not None:
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
+ if prompt_embeds_2 is not None:
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
+ if prompt_mask_2 is not None:
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
+
+
+ # 4. Prepare timesteps
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ **extra_set_timesteps_kwargs,
+ )
+
+ if "884" in vae_ver:
+ video_length = (video_length - 1) // 4 + 1
+ elif "888" in vae_ver:
+ video_length = (video_length - 1) // 8 + 1
+ else:
+ video_length = video_length
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
+ self.scheduler.step,
+ {"generator": generator, "eta": eta},
+ )
+
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
+ autocast_enabled = (
+ target_dtype != torch.float32
+ ) and not self.args.disable_autocast
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
+ vae_autocast_enabled = (
+ vae_dtype != torch.float32
+ ) and not self.args.disable_autocast
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # if is_progress_bar:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * 2)
+ if self.do_classifier_free_guidance
+ else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+
+ t_expand = t.repeat(latent_model_input.shape[0])
+ guidance_expand = (
+ torch.tensor(
+ [embedded_guidance_scale] * latent_model_input.shape[0],
+ dtype=torch.float32,
+ device=device,
+ ).to(target_dtype)
+ * 1000.0
+ if embedded_guidance_scale is not None
+ else None
+ )
+
+ # predict the noise residual
+ with torch.autocast(
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
+ ):
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
+ latent_model_input, # [2, 16, 33, 24, 42]
+ t_expand, # [2]
+ text_states=prompt_embeds, # [2, 256, 4096]
+ text_mask=prompt_mask, # [2, 256]
+ text_states_2=prompt_embeds_2, # [2, 768]
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
+ guidance=guidance_expand,
+ return_dict=True,
+ )[
+ "x"
+ ]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred,
+ noise_pred_text,
+ guidance_rescale=self.guidance_rescale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop(
+ "negative_prompt_embeds", negative_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ if progress_bar is not None:
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ expand_temporal_dim = False
+ if len(latents.shape) == 4:
+ if isinstance(self.vae, AutoencoderKLCausal3D):
+ latents = latents.unsqueeze(2)
+ expand_temporal_dim = True
+ elif len(latents.shape) == 5:
+ pass
+ else:
+ raise ValueError(
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
+ )
+
+ if (
+ hasattr(self.vae.config, "shift_factor")
+ and self.vae.config.shift_factor
+ ):
+ latents = (
+ latents / self.vae.config.scaling_factor
+ + self.vae.config.shift_factor
+ )
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ with torch.autocast(
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
+ ):
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(
+ latents, return_dict=False, generator=generator
+ )[0]
+ else:
+ image = self.vae.decode(
+ latents, return_dict=False, generator=generator
+ )[0]
+
+ if expand_temporal_dim or image.shape[2] == 1:
+ image = image.squeeze(2)
+
+ else:
+ image = latents
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().float()
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return image
+
+ return HunyuanVideoPipelineOutput(videos=image)
diff --git a/hunyuan_model/posemb_layers.py b/hunyuan_model/posemb_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfce82c690540d17a55a51b7997ee7ceb0bdbf44
--- /dev/null
+++ b/hunyuan_model/posemb_layers.py
@@ -0,0 +1,310 @@
+import torch
+from typing import Union, Tuple, List
+
+
+def _to_tuple(x, dim=2):
+ if isinstance(x, int):
+ return (x,) * dim
+ elif len(x) == dim:
+ return x
+ else:
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+def get_meshgrid_nd(start, *args, dim=2):
+ """
+ Get n-D meshgrid with start, stop and num.
+
+ Args:
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
+ n-tuples.
+ *args: See above.
+ dim (int): Dimension of the meshgrid. Defaults to 2.
+
+ Returns:
+ grid (np.ndarray): [dim, ...]
+ """
+ if len(args) == 0:
+ # start is grid_size
+ num = _to_tuple(start, dim=dim)
+ start = (0,) * dim
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = _to_tuple(start, dim=dim)
+ stop = _to_tuple(args[0], dim=dim)
+ num = [stop[i] - start[i] for i in range(dim)]
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
+ axis_grid = []
+ for i in range(dim):
+ a, b, n = start[i], stop[i], num[i]
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+ axis_grid.append(g)
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
+
+ return grid
+
+
+#################################################################################
+# Rotary Positional Embedding Functions #
+#################################################################################
+# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
+
+
+def reshape_for_broadcast(
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ x: torch.Tensor,
+ head_first=False,
+):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Notes:
+ When using FlashMHAModified, head_first should be False.
+ When using Attention, head_first should be True.
+
+ Args:
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ Raises:
+ AssertionError: If the frequency tensor doesn't match the expected shape.
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+
+ if isinstance(freqs_cis, tuple):
+ # freqs_cis: (cos, sin) in real space
+ if head_first:
+ assert freqs_cis[0].shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis[0].shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+ else:
+ # freqs_cis: values in complex space
+ if head_first:
+ assert freqs_cis.shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis.shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+ x_real, x_imag = (
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+ ) # [B, S, H, D//2]
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ head_first: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+ """
+ xk_out = None
+ if isinstance(freqs_cis, tuple):
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ # real * cos - imag * sin
+ # imag * cos + real * sin
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+ else:
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
+ xq_ = torch.view_as_complex(
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
+ xq.device
+ ) # [S, D//2] --> [1, S, 1, D//2]
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+ xk_ = torch.view_as_complex(
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+ return xq_out, xk_out
+
+
+def get_nd_rotary_pos_embed(
+ rope_dim_list,
+ start,
+ *args,
+ theta=10000.0,
+ use_real=False,
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
+ interpolation_factor: Union[float, List[float]] = 1.0,
+):
+ """
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
+
+ Args:
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
+ sum(rope_dim_list) should equal to head_dim of attention layer.
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+ *args: See above.
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
+ part and an imaginary part separately.
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+ pos_embed (torch.Tensor): [HW, D/2]
+ """
+
+ grid = get_meshgrid_nd(
+ start, *args, dim=len(rope_dim_list)
+ ) # [3, W, H, D] / [2, W, H]
+
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+ assert len(theta_rescale_factor) == len(
+ rope_dim_list
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
+
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+ assert len(interpolation_factor) == len(
+ rope_dim_list
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
+
+ # use 1/ndim of dimensions to encode grid_axis
+ embs = []
+ for i in range(len(rope_dim_list)):
+ emb = get_1d_rotary_pos_embed(
+ rope_dim_list[i],
+ grid[i].reshape(-1),
+ theta,
+ use_real=use_real,
+ theta_rescale_factor=theta_rescale_factor[i],
+ interpolation_factor=interpolation_factor[i],
+ ) # 2 x [WHD, rope_dim_list[i]]
+ embs.append(emb)
+
+ if use_real:
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[torch.FloatTensor, int],
+ theta: float = 10000.0,
+ use_real: bool = False,
+ theta_rescale_factor: float = 1.0,
+ interpolation_factor: float = 1.0,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool, optional): If True, return real part and imaginary part separately.
+ Otherwise, return complex numbers.
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
+ """
+ if isinstance(pos, int):
+ pos = torch.arange(pos).float()
+
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ if theta_rescale_factor != 1.0:
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ ) # [D/2]
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(
+ torch.ones_like(freqs), freqs
+ ) # complex64 # [S, D/2]
+ return freqs_cis
diff --git a/hunyuan_model/text_encoder.py b/hunyuan_model/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a8bb74ce5f94a11b8fbe6021a6bb7321381e023
--- /dev/null
+++ b/hunyuan_model/text_encoder.py
@@ -0,0 +1,438 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
+from transformers.utils import ModelOutput
+from transformers.models.llama import LlamaModel
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+# When using decoder-only models, we must provide a prompt template to instruct the text encoder
+# on how to generate the text.
+# --------------------------------------------------------------------
+PROMPT_TEMPLATE_ENCODE = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+)
+PROMPT_TEMPLATE_ENCODE_VIDEO = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+)
+
+NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
+
+PROMPT_TEMPLATE = {
+ "dit-llm-encode": {
+ "template": PROMPT_TEMPLATE_ENCODE,
+ "crop_start": 36,
+ },
+ "dit-llm-encode-video": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
+ "crop_start": 95,
+ },
+}
+
+
+def use_default(value, default):
+ return value if value is not None else default
+
+
+def load_text_encoder(
+ text_encoder_type: str,
+ text_encoder_path: str,
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
+):
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
+
+ # reduce peak memory usage by specifying the dtype of the model
+ dtype = text_encoder_dtype
+ if text_encoder_type == "clipL":
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
+ elif text_encoder_type == "llm":
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
+ text_encoder.final_layer_norm = text_encoder.norm
+ else:
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
+ # from_pretrained will ensure that the model is in eval mode.
+
+ if dtype is not None:
+ text_encoder = text_encoder.to(dtype=dtype)
+
+ text_encoder.requires_grad_(False)
+
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
+ return text_encoder, text_encoder_path
+
+
+def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
+
+ if tokenizer_type == "clipL":
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
+ elif tokenizer_type == "llm":
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
+ else:
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
+
+ return tokenizer, tokenizer_path
+
+
+@dataclass
+class TextEncoderModelOutput(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
+ List of decoded texts.
+ """
+
+ hidden_state: torch.FloatTensor = None
+ attention_mask: Optional[torch.LongTensor] = None
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
+ text_outputs: Optional[list] = None
+
+
+class TextEncoder(nn.Module):
+ def __init__(
+ self,
+ text_encoder_type: str,
+ max_length: int,
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
+ text_encoder_path: Optional[str] = None,
+ tokenizer_type: Optional[str] = None,
+ tokenizer_path: Optional[str] = None,
+ output_key: Optional[str] = None,
+ use_attention_mask: bool = True,
+ input_max_length: Optional[int] = None,
+ prompt_template: Optional[dict] = None,
+ prompt_template_video: Optional[dict] = None,
+ hidden_state_skip_layer: Optional[int] = None,
+ apply_final_norm: bool = False,
+ reproduce: bool = False,
+ ):
+ super().__init__()
+ self.text_encoder_type = text_encoder_type
+ self.max_length = max_length
+ # self.precision = text_encoder_precision
+ self.model_path = text_encoder_path
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
+ self.use_attention_mask = use_attention_mask
+ if prompt_template_video is not None:
+ assert use_attention_mask is True, "Attention mask is True required when training videos."
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
+ self.prompt_template = prompt_template
+ self.prompt_template_video = prompt_template_video
+ self.hidden_state_skip_layer = hidden_state_skip_layer
+ self.apply_final_norm = apply_final_norm
+ self.reproduce = reproduce
+
+ self.use_template = self.prompt_template is not None
+ if self.use_template:
+ assert (
+ isinstance(self.prompt_template, dict) and "template" in self.prompt_template
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
+ assert "{}" in str(self.prompt_template["template"]), (
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
+ f"got {self.prompt_template['template']}"
+ )
+
+ self.use_video_template = self.prompt_template_video is not None
+ if self.use_video_template:
+ if self.prompt_template_video is not None:
+ assert (
+ isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
+ assert "{}" in str(self.prompt_template_video["template"]), (
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
+ f"got {self.prompt_template_video['template']}"
+ )
+
+ if "t5" in text_encoder_type:
+ self.output_key = output_key or "last_hidden_state"
+ elif "clip" in text_encoder_type:
+ self.output_key = output_key or "pooler_output"
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
+ self.output_key = output_key or "last_hidden_state"
+ else:
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
+
+ self.model, self.model_path = load_text_encoder(
+ text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
+ )
+ self.dtype = self.model.dtype
+
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
+ tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
+ )
+
+ def __repr__(self):
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
+
+ @property
+ def device(self):
+ return self.model.device
+
+ @staticmethod
+ def apply_text_to_template(text, template, prevent_empty_text=True):
+ """
+ Apply text to template.
+
+ Args:
+ text (str): Input text.
+ template (str or list): Template string or list of chat conversation.
+ prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
+ by adding a space. Defaults to True.
+ """
+ if isinstance(template, str):
+ # Will send string to tokenizer. Used for llm
+ return template.format(text)
+ else:
+ raise TypeError(f"Unsupported template type: {type(template)}")
+
+ def text2tokens(self, text, data_type="image"):
+ """
+ Tokenize the input text.
+
+ Args:
+ text (str or list): Input text.
+ """
+ tokenize_input_type = "str"
+ if self.use_template:
+ if data_type == "image":
+ prompt_template = self.prompt_template["template"]
+ elif data_type == "video":
+ prompt_template = self.prompt_template_video["template"]
+ else:
+ raise ValueError(f"Unsupported data type: {data_type}")
+ if isinstance(text, (list, tuple)):
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
+ if isinstance(text[0], list):
+ tokenize_input_type = "list"
+ elif isinstance(text, str):
+ text = self.apply_text_to_template(text, prompt_template)
+ if isinstance(text, list):
+ tokenize_input_type = "list"
+ else:
+ raise TypeError(f"Unsupported text type: {type(text)}")
+
+ kwargs = dict(
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ if tokenize_input_type == "str":
+ return self.tokenizer(
+ text,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ **kwargs,
+ )
+ elif tokenize_input_type == "list":
+ return self.tokenizer.apply_chat_template(
+ text,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ **kwargs,
+ )
+ else:
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
+
+ def encode(
+ self,
+ batch_encoding,
+ use_attention_mask=None,
+ output_hidden_states=False,
+ do_sample=None,
+ hidden_state_skip_layer=None,
+ return_texts=False,
+ data_type="image",
+ device=None,
+ ):
+ """
+ Args:
+ batch_encoding (dict): Batch encoding from tokenizer.
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
+ Defaults to None.
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
+ output_hidden_states will be set True. Defaults to False.
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
+ When self.produce is False, do_sample is set to True by default.
+ hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
+ If None, self.output_key will be used. Defaults to None.
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
+ """
+ device = self.model.device if device is None else device
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
+ do_sample = use_default(do_sample, not self.reproduce)
+ attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
+ outputs = self.model(
+ input_ids=batch_encoding["input_ids"].to(device),
+ attention_mask=attention_mask,
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
+ )
+ if hidden_state_skip_layer is not None:
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
+ # Real last hidden state already has layer norm applied. So here we only apply it
+ # for intermediate layers.
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
+ else:
+ last_hidden_state = outputs[self.output_key]
+
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
+ if self.use_template:
+ if data_type == "image":
+ crop_start = self.prompt_template.get("crop_start", -1)
+ elif data_type == "video":
+ crop_start = self.prompt_template_video.get("crop_start", -1)
+ else:
+ raise ValueError(f"Unsupported data type: {data_type}")
+ if crop_start > 0:
+ last_hidden_state = last_hidden_state[:, crop_start:]
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
+
+ if output_hidden_states:
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
+
+ def forward(
+ self,
+ text,
+ use_attention_mask=None,
+ output_hidden_states=False,
+ do_sample=False,
+ hidden_state_skip_layer=None,
+ return_texts=False,
+ ):
+ batch_encoding = self.text2tokens(text)
+ return self.encode(
+ batch_encoding,
+ use_attention_mask=use_attention_mask,
+ output_hidden_states=output_hidden_states,
+ do_sample=do_sample,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ return_texts=return_texts,
+ )
+
+
+# region HunyanVideo architecture
+
+
+def load_text_encoder_1(
+ text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
+) -> TextEncoder:
+ text_encoder_dtype = dtype or torch.float16
+ text_encoder_type = "llm"
+ text_len = 256
+ hidden_state_skip_layer = 2
+ apply_final_norm = False
+ reproduce = False
+
+ prompt_template = "dit-llm-encode"
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
+ prompt_template_video = "dit-llm-encode-video"
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
+
+ crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
+ max_length = text_len + crop_start
+
+ text_encoder_1 = TextEncoder(
+ text_encoder_type=text_encoder_type,
+ max_length=max_length,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=text_encoder_dir,
+ tokenizer_type=text_encoder_type,
+ prompt_template=prompt_template,
+ prompt_template_video=prompt_template_video,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ apply_final_norm=apply_final_norm,
+ reproduce=reproduce,
+ )
+ text_encoder_1.eval()
+
+ if fp8_llm:
+ org_dtype = text_encoder_1.dtype
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
+ text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
+
+ # prepare LLM for fp8
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
+ def forward_hook(module):
+ def forward(hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
+
+ return forward
+
+ for module in llama_model.modules():
+ if module.__class__.__name__ in ["Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ prepare_fp8(text_encoder_1.model, org_dtype)
+ else:
+ text_encoder_1.to(device=device)
+
+ return text_encoder_1
+
+
+def load_text_encoder_2(
+ text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
+) -> TextEncoder:
+ text_encoder_dtype = dtype or torch.float16
+ reproduce = False
+
+ text_encoder_2_type = "clipL"
+ text_len_2 = 77
+
+ text_encoder_2 = TextEncoder(
+ text_encoder_type=text_encoder_2_type,
+ max_length=text_len_2,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=text_encoder_dir,
+ tokenizer_type=text_encoder_2_type,
+ reproduce=reproduce,
+ )
+ text_encoder_2.eval()
+
+ text_encoder_2.to(device=device)
+
+ return text_encoder_2
+
+
+# endregion
diff --git a/hunyuan_model/token_refiner.py b/hunyuan_model/token_refiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f641762fe0da65e7d9037062a282afe860260d14
--- /dev/null
+++ b/hunyuan_model/token_refiner.py
@@ -0,0 +1,236 @@
+from typing import Optional
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .activation_layers import get_activation_layer
+from .attention import attention
+from .norm_layers import get_norm_layer
+from .embed_layers import TimestepEmbedder, TextProjection
+from .mlp_layers import MLP
+from .modulate_layers import modulate, apply_gate
+
+
+class IndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ heads_num,
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.self_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.self_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+ act_layer = get_activation_layer(act_type)
+ self.mlp = MLP(
+ in_channels=hidden_size,
+ hidden_channels=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop_rate,
+ **factory_kwargs,
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ act_layer(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ self.gradient_checkpointing = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
+ attn_mask: torch.Tensor = None,
+ ):
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
+
+ norm_x = self.norm1(x)
+ qkv = self.self_attn_qkv(norm_x)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ # Apply QK-Norm if needed
+ q = self.self_attn_q_norm(q).to(v)
+ k = self.self_attn_k_norm(k).to(v)
+
+ # Self-Attention
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
+
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
+
+ # FFN Layer
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
+
+ return x
+
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+
+class IndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ heads_num,
+ depth,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.blocks = nn.ModuleList(
+ [
+ IndividualTokenRefinerBlock(
+ hidden_size=hidden_size,
+ heads_num=heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ act_type=act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ **factory_kwargs,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ def enable_gradient_checkpointing(self):
+ for block in self.blocks:
+ block.enable_gradient_checkpointing()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ c: torch.LongTensor,
+ mask: Optional[torch.Tensor] = None,
+ ):
+ self_attn_mask = None
+ if mask is not None:
+ batch_size = mask.shape[0]
+ seq_len = mask.shape[1]
+ mask = mask.to(x.device)
+ # batch_size x 1 x seq_len x seq_len
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ # batch_size x 1 x seq_len x seq_len
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ # avoids self-attention weight being NaN for padding tokens
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.blocks:
+ x = block(x, c, self_attn_mask)
+ return x
+
+
+class SingleTokenRefiner(nn.Module):
+ """
+ A single token refiner block for llm text embedding refine.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ heads_num,
+ depth,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ attn_mode: str = "torch",
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
+
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
+
+ act_layer = get_activation_layer(act_type)
+ # Build timestep embedding layer
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
+ # Build context embedding layer
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
+
+ self.individual_token_refiner = IndividualTokenRefiner(
+ hidden_size=hidden_size,
+ heads_num=heads_num,
+ depth=depth,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ act_type=act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ **factory_kwargs,
+ )
+
+ def enable_gradient_checkpointing(self):
+ self.individual_token_refiner.enable_gradient_checkpointing()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.LongTensor,
+ mask: Optional[torch.LongTensor] = None,
+ ):
+ timestep_aware_representations = self.t_embedder(t)
+
+ if mask is None:
+ context_aware_representations = x.mean(dim=1)
+ else:
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ context_aware_representations = self.c_embedder(context_aware_representations)
+ c = timestep_aware_representations + context_aware_representations
+
+ x = self.input_embedder(x)
+
+ x = self.individual_token_refiner(x, c, mask)
+
+ return x
diff --git a/hunyuan_model/vae.py b/hunyuan_model/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..2535809cdd782a6bb69726546ae971723f48da54
--- /dev/null
+++ b/hunyuan_model/vae.py
@@ -0,0 +1,442 @@
+from dataclasses import dataclass
+import json
+from typing import Optional, Tuple, Union
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from diffusers.utils import BaseOutput, is_torch_version
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.attention_processor import SpatialNorm
+from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+SCALING_FACTOR = 0.476986
+VAE_VER = "884-16c-hy"
+
+
+def load_vae(
+ vae_type: str = "884-16c-hy",
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
+ sample_size: tuple = None,
+ vae_path: str = None,
+ device=None,
+):
+ """the fucntion to load the 3D VAE model
+
+ Args:
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
+ vae_precision (str, optional): the precision to load vae. Defaults to None.
+ sample_size (tuple, optional): the tiling size. Defaults to None.
+ vae_path (str, optional): the path to vae. Defaults to None.
+ logger (_type_, optional): logger. Defaults to None.
+ device (_type_, optional): device to load vae. Defaults to None.
+ """
+ if vae_path is None:
+ vae_path = VAE_PATH[vae_type]
+
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
+
+ # use fixed config for Hunyuan's VAE
+ CONFIG_JSON = """{
+ "_class_name": "AutoencoderKLCausal3D",
+ "_diffusers_version": "0.4.2",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 16,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "sample_tsize": 64,
+ "up_block_types": [
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D"
+ ],
+ "scaling_factor": 0.476986,
+ "time_compression_ratio": 4,
+ "mid_block_add_attention": true
+ }"""
+
+ # config = AutoencoderKLCausal3D.load_config(vae_path)
+ config = json.loads(CONFIG_JSON)
+
+ # import here to avoid circular import
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+
+ if sample_size:
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
+ else:
+ vae = AutoencoderKLCausal3D.from_config(config)
+
+ # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
+ # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
+
+ ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ if any(k.startswith("vae.") for k in ckpt.keys()):
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
+ vae.load_state_dict(ckpt)
+
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
+ time_compression_ratio = vae.config.time_compression_ratio
+
+ if vae_dtype is not None:
+ vae = vae.to(vae_dtype)
+
+ vae.requires_grad_(False)
+
+ logger.info(f"VAE to dtype: {vae.dtype}")
+
+ if device is not None:
+ vae = vae.to(device)
+
+ vae.eval()
+
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ r"""
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The decoded output sample from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class EncoderCausal3D(nn.Module):
+ r"""
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ mid_block_add_attention=True,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
+
+ if time_compression_ratio == 4:
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
+
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
+ down_block = get_down_block3d(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
+ downsample_stride=downsample_stride,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ add_attention=mid_block_add_attention,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
+
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `EncoderCausal3D` class."""
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
+
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class DecoderCausal3D(nn.Module):
+ r"""
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ norm_type: str = "group", # group, spatial
+ mid_block_add_attention=True,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ add_attention=mid_block_add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
+
+ if time_compression_ratio == 4:
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
+
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
+ up_block = get_up_block3d(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
+ upsample_scale_factor=upsample_scale_factor,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ latent_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""The forward method of the `DecoderCausal3D` class."""
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
+
+ sample = self.conv_in(sample)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ else:
+ # middle
+ sample = self.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample, latent_embeds)
+
+ # post-process
+ if latent_embeds is None:
+ sample = self.conv_norm_out(sample)
+ else:
+ sample = self.conv_norm_out(sample, latent_embeds)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ if parameters.ndim == 3:
+ dim = 2 # (B, L, C)
+ elif parameters.ndim == 5 or parameters.ndim == 4:
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
+ else:
+ raise NotImplementedError
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = randn_tensor(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ reduce_dim = list(range(1, self.mean.ndim))
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=reduce_dim,
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=reduce_dim,
+ )
+
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
diff --git a/hv_generate_video.py b/hv_generate_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc15e27cf4147798550aec92c95a0901623db852
--- /dev/null
+++ b/hv_generate_video.py
@@ -0,0 +1,563 @@
+import argparse
+from datetime import datetime
+from pathlib import Path
+import random
+import sys
+import os
+import time
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torchvision
+import accelerate
+from diffusers.utils.torch_utils import randn_tensor
+from transformers.models.llama import LlamaModel
+from tqdm import tqdm
+import av
+from einops import rearrange
+from safetensors.torch import load_file
+
+from hunyuan_model import vae
+from hunyuan_model.text_encoder import TextEncoder
+from hunyuan_model.text_encoder import PROMPT_TEMPLATE
+from hunyuan_model.vae import load_vae
+from hunyuan_model.models import load_transformer, get_rotary_pos_embed
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+from networks import lora
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def clean_memory_on_device(device):
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ elif device.type == "cpu":
+ pass
+ elif device.type == "mps": # not tested
+ torch.mps.empty_cache()
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
+ """save videos by video tensor
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
+
+ Args:
+ videos (torch.Tensor): video tensor predicted by the model
+ path (str): path to save video
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
+ n_rows (int, optional): Defaults to 1.
+ fps (int, optional): video save fps. Defaults to 8.
+ """
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = torch.clamp(x, 0, 1)
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ # # save video with av
+ # container = av.open(path, "w")
+ # stream = container.add_stream("libx264", rate=fps)
+ # for x in outputs:
+ # frame = av.VideoFrame.from_ndarray(x, format="rgb24")
+ # packet = stream.encode(frame)
+ # container.mux(packet)
+ # packet = stream.encode(None)
+ # container.mux(packet)
+ # container.close()
+
+ height, width, _ = outputs[0].shape
+
+ # create output container
+ container = av.open(path, mode="w")
+
+ # create video stream
+ codec = "libx264"
+ pixel_format = "yuv420p"
+ stream = container.add_stream(codec, rate=fps)
+ stream.width = width
+ stream.height = height
+ stream.pix_fmt = pixel_format
+ stream.bit_rate = 4000000 # 4Mbit/s
+
+ for frame_array in outputs:
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+ packets = stream.encode(frame)
+ for packet in packets:
+ container.mux(packet)
+
+ for packet in stream.encode():
+ container.mux(packet)
+
+ container.close()
+
+
+# region Encoding prompt
+
+
+def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ text_encoder (TextEncoder):
+ text encoder to be used for encoding the prompt
+ """
+ # LoRA and Textual Inversion are not supported in this script
+ # negative prompt and prompt embedding are not supported in this script
+ # clip_skip is not supported in this script because it is not used in the original script
+ data_type = "video" # video only, image is not supported
+
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ with torch.no_grad():
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
+ prompt_embeds = prompt_outputs.hidden_state
+
+ attention_mask = prompt_outputs.attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(device)
+ bs_embed, seq_len = attention_mask.shape
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
+
+ prompt_embeds_dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ if prompt_embeds.ndim == 2:
+ bs_embed, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
+ else:
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, attention_mask
+
+
+def encode_input_prompt(prompt, args, device, fp8_llm=False, accelerator=None):
+ # constants
+ prompt_template_video = "dit-llm-encode-video"
+ prompt_template = "dit-llm-encode"
+ text_encoder_dtype = torch.float16
+ text_encoder_type = "llm"
+ text_len = 256
+ hidden_state_skip_layer = 2
+ apply_final_norm = False
+ reproduce = False
+
+ text_encoder_2_type = "clipL"
+ text_len_2 = 77
+
+ num_videos = 1
+
+ # if args.prompt_template_video is not None:
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
+ # elif args.prompt_template is not None:
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
+ # else:
+ # crop_start = 0
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
+ max_length = text_len + crop_start
+
+ # prompt_template
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
+
+ # prompt_template_video
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
+
+ # load text encoders
+ logger.info(f"loading text encoder: {args.text_encoder1}")
+ text_encoder = TextEncoder(
+ text_encoder_type=text_encoder_type,
+ max_length=max_length,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=args.text_encoder1,
+ tokenizer_type=text_encoder_type,
+ prompt_template=prompt_template,
+ prompt_template_video=prompt_template_video,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ apply_final_norm=apply_final_norm,
+ reproduce=reproduce,
+ )
+ text_encoder.eval()
+ if fp8_llm:
+ org_dtype = text_encoder.dtype
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
+ text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
+
+ # prepare LLM for fp8
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
+ def forward_hook(module):
+ def forward(hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
+
+ return forward
+
+ for module in llama_model.modules():
+ if module.__class__.__name__ in ["Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ prepare_fp8(text_encoder.model, org_dtype)
+
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
+ text_encoder_2 = TextEncoder(
+ text_encoder_type=text_encoder_2_type,
+ max_length=text_len_2,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=args.text_encoder2,
+ tokenizer_type=text_encoder_2_type,
+ reproduce=reproduce,
+ )
+ text_encoder_2.eval()
+
+ # encode prompt
+ logger.info(f"Encoding prompt with text encoder 1")
+ text_encoder.to(device=device)
+ if fp8_llm:
+ with accelerator.autocast():
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
+ else:
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
+ text_encoder = None
+ clean_memory_on_device(device)
+
+ logger.info(f"Encoding prompt with text encoder 2")
+ text_encoder_2.to(device=device)
+ prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
+
+ prompt_embeds = prompt_embeds.to("cpu")
+ prompt_mask = prompt_mask.to("cpu")
+ prompt_embeds_2 = prompt_embeds_2.to("cpu")
+ prompt_mask_2 = prompt_mask_2.to("cpu")
+
+ text_encoder_2 = None
+ clean_memory_on_device(device)
+
+ return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
+
+
+# endregion
+
+
+def decode_latents(args, latents, device):
+ vae_dtype = torch.float16
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
+ vae.eval()
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
+
+ # set chunk_size to CausalConv3d recursively
+ chunk_size = args.vae_chunk_size
+ if chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
+
+ expand_temporal_dim = False
+ if len(latents.shape) == 4:
+ latents = latents.unsqueeze(2)
+ expand_temporal_dim = True
+ elif len(latents.shape) == 5:
+ pass
+ else:
+ raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
+
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
+ else:
+ latents = latents / vae.config.scaling_factor
+
+ latents = latents.to(device=device, dtype=vae.dtype)
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ # elif args.vae_tiling:
+ else:
+ vae.enable_spatial_tiling(True)
+ with torch.no_grad():
+ image = vae.decode(latents, return_dict=False)[0]
+
+ if expand_temporal_dim or image.shape[2] == 1:
+ image = image.squeeze(2)
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().float()
+
+ return image
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
+
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
+ parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
+
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, default=1.0, help="LoRA multiplier")
+
+ parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
+ parser.add_argument("--video_length", type=int, default=129, help="video length")
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+ parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.")
+
+ # Flow Matching
+ parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "sdpa"], help="attention mode"
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
+ parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
+ parser.add_argument("--output_type", type=str, default="video", help="output type: video, latent or both")
+ parser.add_argument("--latent_path", type=str, default=None, help="path to latent for decode. no inference")
+
+ args = parser.parse_args()
+
+ assert args.latent_path is None or args.output_type == "video", "latent-path is only supported with output-type=video"
+
+ # update dit_weight based on model_base if not exists
+
+ return args
+
+
+def check_inputs(args):
+ height = args.video_size[0]
+ width = args.video_size[1]
+ video_length = args.video_length
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ return height, width, video_length
+
+
+def main():
+ args = parse_args()
+
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ dit_dtype = torch.bfloat16
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
+ logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
+
+ if args.latent_path is not None:
+ latents = torch.load(args.latent_path, map_location="cpu")
+ logger.info(f"Loaded latent from {args.latent_path}. Shape: {latents.shape}")
+ latents = latents.unsqueeze(0)
+ seeds = [0] # dummy seed
+ else:
+ # prepare accelerator
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
+
+ # load prompt
+ prompt = args.prompt # TODO load prompts from file
+ assert prompt is not None, "prompt is required"
+
+ # check inputs: may be height, width, video_length etc will be changed for each generation in future
+ height, width, video_length = check_inputs(args)
+
+ # encode prompt with LLM and Text Encoder
+ logger.info(f"Encoding prompt: {prompt}")
+ prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
+ prompt, args, device, args.fp8_llm, accelerator
+ )
+
+ # load DiT model
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
+ loading_device = "cpu" if blocks_to_swap > 0 else device
+
+ logger.info(f"Loading DiT model from {args.dit}")
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch"
+ transformer = load_transformer(args.dit, args.attn_mode, loading_device, dit_dtype)
+ transformer.eval()
+
+ # load LoRA weights
+ if args.lora_weight is not None:
+ logger.info(f"Loading LoRA weights from {args.lora_weight}")
+ weights_sd = load_file(args.lora_weight)
+ network = lora.create_network_from_weights_hunyuan_video(
+ args.lora_multiplier, weights_sd, unet=transformer, for_inference=True
+ )
+ logger.info("Merging LoRA weights to DiT model")
+ network.merge_to(None, transformer, weights_sd, device=device)
+ logger.info("LoRA weights loaded")
+
+ if blocks_to_swap > 0:
+ logger.info(f"Casting model to {dit_weight_dtype}")
+ transformer.to(dtype=dit_weight_dtype)
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
+ transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
+ transformer.move_to_device_except_swap_blocks(device)
+ transformer.prepare_block_swap_before_forward()
+ else:
+ logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}")
+ transformer.to(device=device, dtype=dit_weight_dtype)
+ if args.img_in_txt_in_offloading:
+ logger.info("Enable offloading img_in and txt_in to CPU")
+ transformer.enable_img_in_txt_in_offloading()
+
+ # load scheduler
+ logger.info(f"Loading scheduler")
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
+
+ # Prepare timesteps
+ num_inference_steps = args.infer_steps
+ scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
+ timesteps = scheduler.timesteps
+
+ # Prepare generator
+ num_videos_per_prompt = 1 # args.num_videos
+ seed = args.seed
+ if seed is None:
+ seeds = [random.randint(0, 1_000_000) for _ in range(num_videos_per_prompt)]
+ elif isinstance(seed, int):
+ seeds = [seed + i for i in range(num_videos_per_prompt)]
+ else:
+ raise ValueError(f"Seed must be an integer or None, got {seed}.")
+ generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
+
+ # Prepare latents
+ num_channels_latents = 16 # transformer.config.in_channels
+ vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
+
+ vae_ver = vae.VAE_VER
+ if "884" in vae_ver:
+ latent_video_length = (video_length - 1) // 4 + 1
+ elif "888" in vae_ver:
+ latent_video_length = (video_length - 1) // 8 + 1
+ else:
+ latent_video_length = video_length
+
+ shape = (
+ num_videos_per_prompt,
+ num_channels_latents,
+ latent_video_length,
+ height // vae_scale_factor,
+ width // vae_scale_factor,
+ )
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
+ # FlowMatchDiscreteScheduler does not have init_noise_sigma
+
+ # Denoising loop
+ embedded_guidance_scale = args.embedded_cfg_scale
+ if embedded_guidance_scale is not None:
+ guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
+ guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
+ else:
+ guidance_expand = None
+ freqs_cos, freqs_sin = get_rotary_pos_embed(vae.VAE_VER, transformer, video_length, height, width)
+ # n_tokens = freqs_cos.shape[0]
+
+ # move and cast all inputs to the correct device and dtype
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
+ prompt_mask = prompt_mask.to(device=device)
+ prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
+ prompt_mask_2 = prompt_mask_2.to(device=device)
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
+ # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
+ with tqdm(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latents = scheduler.scale_model_input(latents, t)
+
+ # predict the noise residual
+ with torch.no_grad(), accelerator.autocast():
+ noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
+ latents, # [1, 16, 33, 24, 42]
+ t.repeat(latents.shape[0]).to(device=device, dtype=dit_dtype), # [1]
+ text_states=prompt_embeds, # [1, 256, 4096]
+ text_mask=prompt_mask, # [1, 256]
+ text_states_2=prompt_embeds_2, # [1, 768]
+ freqs_cos=freqs_cos, # [seqlen, head_dim]
+ freqs_sin=freqs_sin, # [seqlen, head_dim]
+ guidance=guidance_expand,
+ return_dict=True,
+ )["x"]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
+ if progress_bar is not None:
+ progress_bar.update()
+
+ # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
+ # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+
+ latents = latents.detach().cpu()
+ transformer = None
+ clean_memory_on_device(device)
+
+ # Save samples
+ output_type = args.output_type
+ save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ if output_type == "latent" or output_type == "both":
+ # save latent
+ for i, latent in enumerate(latents):
+ latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.pt"
+ torch.save(latent, latent_path)
+ logger.info(f"Latent save to: {latent_path}")
+ if output_type == "video" or output_type == "both":
+ # save video
+ videos = decode_latents(args, latents, device)
+ for i, sample in enumerate(videos):
+ sample = sample.unsqueeze(0)
+ save_path = f"{save_path}/{time_flag}_{seeds[i]}.mp4"
+ save_videos_grid(sample, save_path, fps=24)
+ logger.info(f"Sample save to: {save_path}")
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/hv_train_network.py b/hv_train_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c9e30d8123611dba167f32bdc58bb61be67491
--- /dev/null
+++ b/hv_train_network.py
@@ -0,0 +1,2129 @@
+import ast
+import asyncio
+from datetime import datetime
+import gc
+import importlib
+import argparse
+import math
+import os
+import pathlib
+import re
+import sys
+import random
+import time
+import json
+from multiprocessing import Value
+from typing import Any, Dict, List, Optional
+import accelerate
+import numpy as np
+from packaging.version import Version
+
+import huggingface_hub
+import toml
+
+import torch
+from tqdm import tqdm
+from accelerate.utils import set_seed
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from safetensors.torch import load_file
+import transformers
+from diffusers.optimization import (
+ SchedulerType as DiffusersSchedulerType,
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+
+from dataset import config_utils
+from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
+import hunyuan_model.text_encoder as text_encoder_module
+from hunyuan_model.vae import load_vae
+import hunyuan_model.vae as vae_module
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+import networks.lora as lora_module
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+
+import logging
+
+from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
+
+SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
+SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
+SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
+SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
+SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
+
+SS_METADATA_MINIMUM_KEYS = [
+ SS_METADATA_KEY_BASE_MODEL_VERSION,
+ SS_METADATA_KEY_NETWORK_MODULE,
+ SS_METADATA_KEY_NETWORK_DIM,
+ SS_METADATA_KEY_NETWORK_ALPHA,
+ SS_METADATA_KEY_NETWORK_ARGS,
+]
+
+
+def clean_memory_on_device(device: torch.device):
+ r"""
+ Clean memory on the specified device, will be called from training scripts.
+ """
+ gc.collect()
+
+ # device may "cuda" or "cuda:0", so we need to check the type of device
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ if device.type == "xpu":
+ torch.xpu.empty_cache()
+ if device.type == "mps":
+ torch.mps.empty_cache()
+
+
+# for collate_fn: epoch and step is multiprocessing.Value
+class collator_class:
+ def __init__(self, epoch, step, dataset):
+ self.current_epoch = epoch
+ self.current_step = step
+ self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
+
+ def __call__(self, examples):
+ worker_info = torch.utils.data.get_worker_info()
+ # worker_info is None in the main process
+ if worker_info is not None:
+ dataset = worker_info.dataset
+ else:
+ dataset = self.dataset
+
+ # set epoch and step
+ dataset.set_current_epoch(self.current_epoch.value)
+ dataset.set_current_step(self.current_step.value)
+ return examples[0]
+
+
+def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
+ """
+ DeepSpeed is not supported in this script currently.
+ """
+ if args.logging_dir is None:
+ logging_dir = None
+ else:
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
+
+ if args.log_with is None:
+ if logging_dir is not None:
+ log_with = "tensorboard"
+ else:
+ log_with = None
+ else:
+ log_with = args.log_with
+ if log_with in ["tensorboard", "all"]:
+ if logging_dir is None:
+ raise ValueError(
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
+ )
+ if log_with in ["wandb", "all"]:
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError("No wandb / wandb がインストールされていないようです")
+ if logging_dir is not None:
+ os.makedirs(logging_dir, exist_ok=True)
+ os.environ["WANDB_DIR"] = logging_dir
+ if args.wandb_api_key is not None:
+ wandb.login(key=args.wandb_api_key)
+
+ kwargs_handlers = [
+ (
+ InitProcessGroupKwargs(
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+ init_method=(
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
+ ),
+ timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
+ )
+ if torch.cuda.device_count() > 1
+ else None
+ ),
+ (
+ DistributedDataParallelKwargs(
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
+ )
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+ else None
+ ),
+ ]
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=log_with,
+ project_dir=logging_dir,
+ kwargs_handlers=kwargs_handlers,
+ )
+ print("accelerator device:", accelerator.device)
+ return accelerator
+
+
+def line_to_prompt_dict(line: str) -> dict:
+ # subset of gen_img_diffusers
+ prompt_args = line.split(" --")
+ prompt_dict = {}
+ prompt_dict["prompt"] = prompt_args[0]
+
+ for parg in prompt_args:
+ try:
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["width"] = int(m.group(1))
+ continue
+
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["height"] = int(m.group(1))
+ continue
+
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["frame_count"] = int(m.group(1))
+ continue
+
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["seed"] = int(m.group(1))
+ continue
+
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+ if m: # steps
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
+ continue
+
+ # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+ # if m: # scale
+ # prompt_dict["scale"] = float(m.group(1))
+ # continue
+ # m = re.match(r"n (.+)", parg, re.IGNORECASE)
+ # if m: # negative prompt
+ # prompt_dict["negative_prompt"] = m.group(1)
+ # continue
+
+ except ValueError as ex:
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
+ logger.error(ex)
+
+ return prompt_dict
+
+
+def load_prompts(prompt_file: str) -> list[Dict]:
+ # read prompts
+ if prompt_file.endswith(".txt"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
+ elif prompt_file.endswith(".toml"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ data = toml.load(f)
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
+ elif prompt_file.endswith(".json"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ prompts = json.load(f)
+
+ # preprocess prompts
+ for i in range(len(prompts)):
+ prompt_dict = prompts[i]
+ if isinstance(prompt_dict, str):
+ prompt_dict = line_to_prompt_dict(prompt_dict)
+ prompts[i] = prompt_dict
+ assert isinstance(prompt_dict, dict)
+
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+ prompt_dict["enum"] = i
+ prompt_dict.pop("subset", None)
+
+ return prompts
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+ """Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device="cpu")
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device="cpu")
+ return u
+
+
+def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ # round to nearest timestep
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
+ else:
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
+ """Computes loss weighting scheme for SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
+ if weighting_scheme == "sigma_sqrt":
+ weighting = (sigmas**-2.0).float()
+ else:
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ weighting = 2 / (math.pi * bot)
+ else:
+ weighting = None # torch.ones_like(sigmas)
+ return weighting
+
+
+class NetworkTrainer:
+ def __init__(self):
+ pass
+
+ # TODO 他のスクリプトと共通化する
+ def generate_step_logs(
+ self,
+ args: argparse.Namespace,
+ current_loss,
+ avr_loss,
+ lr_scheduler,
+ lr_descriptions,
+ optimizer=None,
+ keys_scaled=None,
+ mean_norm=None,
+ maximum_norm=None,
+ ):
+ network_train_unet_only = True
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+ if keys_scaled is not None:
+ logs["max_norm/keys_scaled"] = keys_scaled
+ logs["max_norm/average_key_norm"] = mean_norm
+ logs["max_norm/max_key_norm"] = maximum_norm
+
+ lrs = lr_scheduler.get_last_lr()
+ for i, lr in enumerate(lrs):
+ if lr_descriptions is not None:
+ lr_desc = lr_descriptions[i]
+ else:
+ idx = i - (0 if network_train_unet_only else -1)
+ if idx == -1:
+ lr_desc = "textencoder"
+ else:
+ if len(lrs) > 2:
+ lr_desc = f"group{idx}"
+ else:
+ lr_desc = "unet"
+
+ logs[f"lr/{lr_desc}"] = lr
+
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+ # tracking d*lr value
+ logs[f"lr/d*lr/{lr_desc}"] = (
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+ )
+ if (
+ args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
+ ): # tracking d*lr value of unet.
+ logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+ else:
+ idx = 0
+ if not network_train_unet_only:
+ logs["lr/textencoder"] = float(lrs[0])
+ idx = 1
+
+ for i in range(idx, len(lrs)):
+ logs[f"lr/group{i}"] = float(lrs[i])
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+ logs[f"lr/d*lr/group{i}"] = (
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+ )
+ if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
+ logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+
+ return logs
+
+ def process_sample_prompts(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ sample_prompts: str,
+ text_encoder1: str,
+ text_encoder2: str,
+ fp8_llm: bool,
+ ):
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
+ prompts = load_prompts(sample_prompts)
+
+ def encode_for_text_encoder(text_encoder):
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
+ with accelerator.autocast(), torch.no_grad():
+ for prompt_dict in prompts:
+ for p in [prompt_dict.get("prompt", "")]:
+ if p not in sample_prompts_te_outputs:
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+
+ data_type = "video"
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
+
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
+
+ return sample_prompts_te_outputs
+
+ # Load Text Encoder 1 and encode
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
+ logger.info(f"loading text encoder 1: {text_encoder1}")
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 1")
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
+ del text_encoder_1
+
+ # Load Text Encoder 2 and encode
+ logger.info(f"loading text encoder 2: {text_encoder2}")
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 2")
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
+ del text_encoder_2
+
+ # prepare sample parameters
+ sample_parameters = []
+ for prompt_dict in prompts:
+ prompt_dict_copy = prompt_dict.copy()
+ p = prompt_dict.get("prompt", "")
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
+ sample_parameters.append(prompt_dict_copy)
+
+ clean_memory_on_device(accelerator.device)
+
+ return sample_parameters
+
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
+ # adamw, adamw8bit, adafactor
+
+ optimizer_type = args.optimizer_type
+
+ # split optimizer_type and optimizer_args
+ optimizer_kwargs = {}
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+ for arg in args.optimizer_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ optimizer_kwargs[key] = value
+
+ lr = args.learning_rate
+ optimizer = None
+ optimizer_class = None
+
+ if optimizer_type.endswith("8bit".lower()):
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+
+ if optimizer_type == "AdamW8bit".lower():
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = bnb.optim.AdamW8bit
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "Adafactor".lower():
+ # Adafactor: check relative_step and warmup_init
+ if "relative_step" not in optimizer_kwargs:
+ optimizer_kwargs["relative_step"] = True # default
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+ logger.info(
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
+ )
+ optimizer_kwargs["relative_step"] = True
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+ if optimizer_kwargs["relative_step"]:
+ logger.info(f"relative_step is true / relative_stepがtrueです")
+ if lr != 0.0:
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+ args.learning_rate = None
+
+ if args.lr_scheduler != "adafactor":
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
+
+ lr = None
+ else:
+ if args.max_grad_norm != 0.0:
+ logger.warning(
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
+ )
+ if args.lr_scheduler != "constant_with_warmup":
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+ optimizer_class = transformers.optimization.Adafactor
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "AdamW".lower():
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = torch.optim.AdamW
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ if optimizer is None:
+ # 任意のoptimizerを使う
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
+ optimizer_module = torch.optim
+ else: # from other library
+ values = case_sensitive_optimizer_type.split(".")
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
+ case_sensitive_optimizer_type = values[-1]
+
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ # for logging
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+ # get train and eval functions
+ if hasattr(optimizer, "train") and callable(optimizer.train):
+ train_fn = optimizer.train
+ eval_fn = optimizer.eval
+ else:
+ train_fn = lambda: None
+ eval_fn = lambda: None
+
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
+
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
+
+ def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any:
+ # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
+ # this scheduler is used for logging only.
+ # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
+ class DummyScheduler:
+ def __init__(self, optimizer: torch.optim.Optimizer):
+ self.optimizer = optimizer
+
+ def step(self):
+ pass
+
+ def get_last_lr(self):
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+ return DummyScheduler(optimizer)
+
+ def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
+ """
+ Unified API to get any scheduler from its name.
+ """
+ # if schedulefree optimizer, return dummy scheduler
+ if self.is_schedulefree_optimizer(optimizer, args):
+ return self.get_dummy_scheduler(optimizer)
+
+ name = args.lr_scheduler
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
+ num_warmup_steps: Optional[int] = (
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+ )
+ num_decay_steps: Optional[int] = (
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+ )
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+ num_cycles = args.lr_scheduler_num_cycles
+ power = args.lr_scheduler_power
+ timescale = args.lr_scheduler_timescale
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
+
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+ for arg in args.lr_scheduler_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ lr_scheduler_kwargs[key] = value
+
+ def wrap_check_needless_num_warmup_steps(return_vals):
+ if num_warmup_steps is not None and num_warmup_steps != 0:
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+ return return_vals
+
+ # using any lr_scheduler from other library
+ if args.lr_scheduler_type:
+ lr_scheduler_type = args.lr_scheduler_type
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+ if "." not in lr_scheduler_type: # default to use torch.optim
+ lr_scheduler_module = torch.optim.lr_scheduler
+ else:
+ values = lr_scheduler_type.split(".")
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+ lr_scheduler_type = values[-1]
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+ return lr_scheduler
+
+ if name.startswith("adafactor"):
+ assert (
+ type(optimizer) == transformers.optimization.Adafactor
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+ initial_lr = float(name.split(":")[1])
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
+
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+ name = DiffusersSchedulerType(name)
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
+
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ if name == SchedulerType.CONSTANT:
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
+
+ if name == SchedulerType.INVERSE_SQRT:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.POLYNOMIAL:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ power=power,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_rate=min_lr_ratio,
+ **lr_scheduler_kwargs,
+ )
+
+ # these schedulers do not require `num_decay_steps`
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ # All other schedulers require `num_decay_steps`
+ if num_decay_steps is None:
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_stable_steps=num_stable_steps,
+ num_decay_steps=num_decay_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+ **lr_scheduler_kwargs,
+ )
+
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_decay_steps=num_decay_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
+ if not args.resume:
+ return False
+
+ if not args.resume_from_huggingface:
+ logger.info(f"resume training from local state: {args.resume}")
+ accelerator.load_state(args.resume)
+ return True
+
+ logger.info(f"resume training from huggingface state: {args.resume}")
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
+ path_in_repo = "/".join(args.resume.split("/")[2:])
+ revision = None
+ repo_type = None
+ if ":" in path_in_repo:
+ divided = path_in_repo.split(":")
+ if len(divided) == 2:
+ path_in_repo, revision = divided
+ repo_type = "model"
+ else:
+ path_in_repo, revision, repo_type = divided
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
+
+ list_files = huggingface_utils.list_dir(
+ repo_id=repo_id,
+ subfolder=path_in_repo,
+ revision=revision,
+ token=args.huggingface_token,
+ repo_type=repo_type,
+ )
+
+ async def download(filename) -> str:
+ def task():
+ return huggingface_hub.hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ revision=revision,
+ repo_type=repo_type,
+ token=args.huggingface_token,
+ )
+
+ return await asyncio.get_event_loop().run_in_executor(None, task)
+
+ loop = asyncio.get_event_loop()
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
+ if len(results) == 0:
+ raise ValueError(
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
+ )
+ dirname = os.path.dirname(results[0])
+ accelerator.load_state(dirname)
+
+ return True
+
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
+ pass
+
+ def get_noisy_model_input_and_timesteps(
+ self,
+ args: argparse.Namespace,
+ noise: torch.Tensor,
+ latents: torch.Tensor,
+ noise_scheduler: FlowMatchDiscreteScheduler,
+ device: torch.device,
+ dtype: torch.dtype,
+ ):
+ batch_size = noise.shape[0]
+
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
+ # Simple random t-based noise sampling
+ if args.timestep_sampling == "sigmoid":
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
+ else:
+ t = torch.rand((batch_size,), device=device)
+
+ elif args.timestep_sampling == "shift":
+ shift = args.discrete_flow_shift
+ logits_norm = torch.randn(batch_size, device=device)
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
+ t = logits_norm.sigmoid()
+ t = (t * shift) / (1 + (shift - 1) * t)
+
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
+ t_min /= 1000.0
+ t_max /= 1000.0
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
+
+ timesteps = t * 1000.0
+ t = t.view(-1, 1, 1, 1, 1)
+ noisy_model_input = (1 - t) * latents + t * noise
+
+ timesteps += 1 # 1 to 1000
+ else:
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=batch_size,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
+ indices = (u * (t_max - t_min) + t_min).long()
+
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+
+ return noisy_model_input, timesteps
+
+ def show_timesteps(self, args: argparse.Namespace):
+ N_TRY = 100000
+ BATCH_SIZE = 1000
+ CONSOLE_WIDTH = 64
+ N_TIMESTEPS_PER_LINE = 25
+
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
+ # print(f"Noise scheduler timesteps: {noise_scheduler.timesteps}")
+
+ latents = torch.zeros(BATCH_SIZE, 1, 1, 1, 1, dtype=torch.float16)
+ noise = torch.ones_like(latents)
+
+ # sample timesteps
+ sampled_timesteps = [0] * noise_scheduler.config.num_train_timesteps
+ for i in tqdm(range(N_TRY // BATCH_SIZE)):
+ # we use noise=1, so retured noisy_model_input is same as timestep, because `noisy_model_input = (1 - t) * latents + t * noise`
+ actual_timesteps, _ = self.get_noisy_model_input_and_timesteps(
+ args, noise, latents, noise_scheduler, "cpu", torch.float16
+ )
+ actual_timesteps = actual_timesteps[:, 0, 0, 0, 0] * 1000
+ for t in actual_timesteps:
+ t = int(t.item())
+ sampled_timesteps[t] += 1
+
+ # sample weighting
+ sampled_weighting = [0] * noise_scheduler.config.num_train_timesteps
+ for i in tqdm(range(len(sampled_weighting))):
+ timesteps = torch.tensor([i + 1], device="cpu")
+ weighting = compute_loss_weighting_for_sd3(args.weighting_scheme, noise_scheduler, timesteps, "cpu", torch.float16)
+ if weighting is None:
+ weighting = torch.tensor(1.0, device="cpu")
+ elif torch.isinf(weighting).any():
+ weighting = torch.tensor(1.0, device="cpu")
+ sampled_weighting[i] = weighting.item()
+
+ # show results
+ if args.show_timesteps == "image":
+ # show timesteps with matplotlib
+ import matplotlib.pyplot as plt
+
+ plt.figure(figsize=(10, 5))
+ plt.subplot(1, 2, 1)
+ plt.bar(range(len(sampled_timesteps)), sampled_timesteps, width=1.0)
+ plt.title("Sampled timesteps")
+ plt.xlabel("Timestep")
+ plt.ylabel("Count")
+
+ plt.subplot(1, 2, 2)
+ plt.bar(range(len(sampled_weighting)), sampled_weighting, width=1.0)
+ plt.title("Sampled loss weighting")
+ plt.xlabel("Timestep")
+ plt.ylabel("Weighting")
+
+ plt.tight_layout()
+ plt.show()
+
+ else:
+ sampled_timesteps = np.array(sampled_timesteps)
+ sampled_weighting = np.array(sampled_weighting)
+
+ # average per line
+ sampled_timesteps = sampled_timesteps.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
+ sampled_weighting = sampled_weighting.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
+
+ max_count = max(sampled_timesteps)
+ print(f"Sampled timesteps: max count={max_count}")
+ for i, t in enumerate(sampled_timesteps):
+ line = f"{(i)*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: "
+ line += "#" * int(t / max_count * CONSOLE_WIDTH)
+ print(line)
+
+ max_weighting = max(sampled_weighting)
+ print(f"Sampled loss weighting: max weighting={max_weighting}")
+ for i, w in enumerate(sampled_weighting):
+ line = f"{i*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: {w:8.2f} "
+ line += "#" * int(w / max_weighting * CONSOLE_WIDTH)
+ print(line)
+
+ def train(self, args):
+ # show timesteps for debugging
+ if args.show_timesteps:
+ self.show_timesteps(args)
+ return
+
+ session_id = random.randint(0, 2**32)
+ training_started_at = time.time()
+ # setup_logging(args, reset=True)
+
+ if args.seed is None:
+ args.seed = random.randint(0, 2**32)
+ set_seed(args.seed)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
+
+ current_epoch = Value("i", 0)
+ current_step = Value("i", 0)
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
+
+ # prepare accelerator
+ logger.info("preparing accelerator")
+ accelerator = prepare_accelerator(args)
+ is_main_process = accelerator.is_main_process
+
+ # prepare dtype
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # HunyuanVideo specific
+ dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype)
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8_base else dit_dtype
+ logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
+
+ # get embedding for sampling images
+ sample_parameters = vae = None
+ if args.sample_prompts:
+ sample_parameters = self.process_sample_prompts(
+ args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
+ )
+
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
+ vae.requires_grad_(False)
+ vae.eval()
+
+ if args.vae_chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ elif args.vae_tiling:
+ vae.enable_spatial_tiling(True)
+
+ # load DiT model
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
+
+ logger.info(f"Loading DiT model from {args.dit}")
+ if args.sdpa:
+ attn_mode = "torch"
+ elif args.flash_attn:
+ attn_mode = "flash"
+ elif args.sage_attn:
+ attn_mode = "sageattn"
+ else:
+ raise ValueError(
+ f"either --sdpa or --flash-attn or --sage-attn must be specified / --sdpaか--flash-attnか--sage-attnのいずれかを指定してください"
+ )
+ transformer = load_transformer(args.dit, attn_mode, loading_device, dit_weight_dtype)
+ transformer.eval()
+ transformer.requires_grad_(False)
+
+ if blocks_to_swap > 0:
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
+ if args.img_in_txt_in_offloading:
+ logger.info("Enable offloading img_in and txt_in to CPU")
+ transformer.enable_img_in_txt_in_offloading()
+
+ # load network model for differential training
+ sys.path.append(os.path.dirname(__file__))
+ accelerator.print("import network module:", args.network_module)
+ network_module: lora_module = importlib.import_module(args.network_module) # actual module may be different
+
+ if args.base_weights is not None:
+ # if base_weights is specified, merge the weights to DiT model
+ for i, weight_path in enumerate(args.base_weights):
+ if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
+ multiplier = 1.0
+ else:
+ multiplier = args.base_weights_multiplier[i]
+
+ accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
+
+ weights_sd = load_file(weight_path)
+ module = network_module.create_network_from_weights_hunyuan_video(multiplier, weights_sd, unet=transformer)
+ module.merge_to(None, transformer, weights_sd, weight_dtype, "cpu")
+
+ accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
+
+ # prepare network
+ net_kwargs = {}
+ if args.network_args is not None:
+ for net_arg in args.network_args:
+ key, value = net_arg.split("=")
+ net_kwargs[key] = value
+
+ if args.dim_from_weights:
+ logger.info(f"Loading network from weights: {args.dim_from_weights}")
+ weights_sd = load_file(args.dim_from_weights)
+ network, _ = network_module.create_network_from_weights_hunyuan_video(1, weights_sd, unet=transformer)
+ else:
+ network = network_module.create_network_hunyuan_video(
+ 1.0,
+ args.network_dim,
+ args.network_alpha,
+ vae,
+ None,
+ transformer,
+ neuron_dropout=args.network_dropout,
+ **net_kwargs,
+ )
+ if network is None:
+ return
+
+ network.prepare_network(args)
+
+ # apply network to DiT
+ network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
+
+ if args.network_weights is not None:
+ # FIXME consider alpha of weights: this assumes that the alpha is not changed
+ info = network.load_weights(args.network_weights)
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ network.enable_gradient_checkpointing() # may have no effect
+
+ # prepare optimizer, data loader etc.
+ accelerator.print("prepare optimizer, data loader etc.")
+
+ trainable_params, lr_descriptions = network.prepare_optimizer_params(unet_lr=args.learning_rate)
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
+ args, trainable_params
+ )
+
+ # prepare dataloader
+
+ # num workers for data loader: if 0, persistent_workers is not available
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset_group,
+ batch_size=1,
+ shuffle=True,
+ collate_fn=collator,
+ num_workers=n_workers,
+ persistent_workers=args.persistent_data_loader_workers,
+ )
+
+ # calculate max_train_steps
+ if args.max_train_epochs is not None:
+ args.max_train_steps = args.max_train_epochs * math.ceil(
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+ )
+ accelerator.print(
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+ )
+
+ # send max_train_steps to train_dataset_group
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
+
+ # prepare lr_scheduler
+ lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
+
+ # prepare training model. accelerator does some magic here
+
+ # experimental feature: train the model with gradients in fp16/bf16
+ network_dtype = torch.float32
+ args.full_fp16 = args.full_bf16 = False # temporary disabled because stochastic rounding is not supported yet
+ if args.full_fp16:
+ assert (
+ args.mixed_precision == "fp16"
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+ accelerator.print("enable full fp16 training.")
+ network_dtype = weight_dtype
+ network.to(network_dtype)
+ elif args.full_bf16:
+ assert (
+ args.mixed_precision == "bf16"
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+ accelerator.print("enable full bf16 training.")
+ network_dtype = weight_dtype
+ network.to(network_dtype)
+
+ if dit_weight_dtype != dit_dtype:
+ logger.info(f"casting model to {dit_weight_dtype}")
+ transformer.to(dit_weight_dtype)
+
+ if blocks_to_swap > 0:
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
+ else:
+ transformer = accelerator.prepare(transformer)
+
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+ training_model = network
+
+ if args.gradient_checkpointing:
+ transformer.train()
+ else:
+ transformer.eval()
+
+ accelerator.unwrap_model(network).prepare_grad_etc(transformer)
+
+ if args.full_fp16:
+ # patch accelerator for fp16 training
+ # def patch_accelerator_for_fp16_training(accelerator):
+ org_unscale_grads = accelerator.scaler._unscale_grads_
+
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
+
+ # before resuming make hook for saving/loading to save/load the network weights only
+ def save_model_hook(models, weights, output_dir):
+ # pop weights of other models than network to save only network weights
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+ if accelerator.is_main_process: # or args.deepspeed:
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ if len(weights) > i:
+ weights.pop(i)
+ # print(f"save model hook: {len(weights)} weights will be saved")
+
+ def load_model_hook(models, input_dir):
+ # remove models except network
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ models.pop(i)
+ # print(f"load model hook: {len(models)} models will be loaded")
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # resume from local or huggingface. accelerator.step is set
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
+
+ # epoch数を計算する
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # 学習する
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ accelerator.print("running training / 学習開始")
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+ accelerator.print(
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ )
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+ # TODO refactor metadata creation and move to util
+ metadata = {
+ "ss_" "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
+ "ss_training_started_at": training_started_at, # unix timestamp
+ "ss_output_name": args.output_name,
+ "ss_learning_rate": args.learning_rate,
+ "ss_num_train_items": train_dataset_group.num_train_items,
+ "ss_num_batches_per_epoch": len(train_dataloader),
+ "ss_num_epochs": num_train_epochs,
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
+ "ss_max_train_steps": args.max_train_steps,
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
+ "ss_lr_scheduler": args.lr_scheduler,
+ SS_METADATA_KEY_BASE_MODEL_VERSION: BASE_MODEL_VERSION_HUNYUAN_VIDEO,
+ # "ss_network_module": args.network_module,
+ # "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
+ # "ss_network_alpha": args.network_alpha, # some networks may not have alpha
+ SS_METADATA_KEY_NETWORK_MODULE: args.network_module,
+ SS_METADATA_KEY_NETWORK_DIM: args.network_dim,
+ SS_METADATA_KEY_NETWORK_ALPHA: args.network_alpha,
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
+ "ss_mixed_precision": args.mixed_precision,
+ "ss_seed": args.seed,
+ "ss_training_comment": args.training_comment, # will not be updated after training
+ # "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
+ "ss_max_grad_norm": args.max_grad_norm,
+ "ss_fp8_base": bool(args.fp8_base),
+ "ss_fp8_llm": bool(args.fp8_llm),
+ "ss_full_fp16": bool(args.full_fp16),
+ "ss_full_bf16": bool(args.full_bf16),
+ "ss_weighting_scheme": args.weighting_scheme,
+ "ss_logit_mean": args.logit_mean,
+ "ss_logit_std": args.logit_std,
+ "ss_mode_scale": args.mode_scale,
+ "ss_guidance_scale": args.guidance_scale,
+ "ss_timestep_sampling": args.timestep_sampling,
+ "ss_sigmoid_scale": args.sigmoid_scale,
+ "ss_discrete_flow_shift": args.discrete_flow_shift,
+ }
+
+ datasets_metadata = []
+ # tag_frequency = {} # merge tag frequency for metadata editor # TODO support tag frequency
+ for dataset in train_dataset_group.datasets:
+ dataset_metadata = dataset.get_metadata()
+ datasets_metadata.append(dataset_metadata)
+
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
+
+ # add extra args
+ if args.network_args:
+ # metadata["ss_network_args"] = json.dumps(net_kwargs)
+ metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(net_kwargs)
+
+ # model name and hash
+ if args.dit is not None:
+ logger.info(f"calculate hash for DiT model: {args.dit}")
+ sd_model_name = args.dit
+ if os.path.exists(sd_model_name):
+ metadata["ss_sd_model_hash"] = model_utils.model_hash(sd_model_name)
+ metadata["ss_new_sd_model_hash"] = model_utils.calculate_sha256(sd_model_name)
+ sd_model_name = os.path.basename(sd_model_name)
+ metadata["ss_sd_model_name"] = sd_model_name
+
+ if args.vae is not None:
+ logger.info(f"calculate hash for VAE model: {args.vae}")
+ vae_name = args.vae
+ if os.path.exists(vae_name):
+ metadata["ss_vae_hash"] = model_utils.model_hash(vae_name)
+ metadata["ss_new_vae_hash"] = model_utils.calculate_sha256(vae_name)
+ vae_name = os.path.basename(vae_name)
+ metadata["ss_vae_name"] = vae_name
+
+ metadata = {k: str(v) for k, v in metadata.items()}
+
+ # make minimum metadata for filtering
+ minimum_metadata = {}
+ for key in SS_METADATA_MINIMUM_KEYS:
+ if key in metadata:
+ minimum_metadata[key] = metadata[key]
+
+ if accelerator.is_main_process:
+ init_kwargs = {}
+ if args.wandb_run_name:
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
+ if args.log_tracker_config is not None:
+ init_kwargs = toml.load(args.log_tracker_config)
+ accelerator.init_trackers(
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_utils.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
+ )
+
+ # TODO skip until initial step
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+
+ epoch_to_start = 0
+ global_step = 0
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
+
+ loss_recorder = train_utils.LossRecorder()
+ del train_dataset_group
+
+ # function for saving/removing
+ save_dtype = dit_dtype
+
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
+ os.makedirs(args.output_dir, exist_ok=True)
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
+
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
+ metadata["ss_training_finished_at"] = str(time.time())
+ metadata["ss_steps"] = str(steps)
+ metadata["ss_epoch"] = str(epoch_no)
+
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
+
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
+ if args.min_timestep is not None or args.max_timestep is not None:
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
+ md_timesteps = (min_time_step, max_time_step)
+ else:
+ md_timesteps = None
+
+ sai_metadata = sai_model_spec.build_metadata(
+ None,
+ time.time(),
+ title,
+ None,
+ args.metadata_author,
+ args.metadata_description,
+ args.metadata_license,
+ args.metadata_tags,
+ timesteps=md_timesteps,
+ )
+
+ metadata_to_save.update(sai_metadata)
+
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+ if args.huggingface_repo_id is not None:
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
+
+ def remove_model(old_ckpt_name):
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
+ if os.path.exists(old_ckpt_file):
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
+ os.remove(old_ckpt_file)
+
+ # For --sample_at_first
+ optimizer_eval_fn()
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
+ optimizer_train_fn()
+ if len(accelerator.trackers) > 0:
+ # log empty object to commit the sample images to wandb
+ accelerator.log({}, step=0)
+
+ # training loop
+
+ # log device and dtype for each model
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+
+ clean_memory_on_device(accelerator.device)
+
+ pos_embed_cache = {}
+
+ for epoch in range(epoch_to_start, num_train_epochs):
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+ current_epoch.value = epoch + 1
+
+ metadata["ss_epoch"] = str(epoch + 1)
+
+ accelerator.unwrap_model(network).on_epoch_start(transformer)
+
+ for step, batch in enumerate(train_dataloader):
+ latents, llm_embeds, llm_mask, clip_embeds = batch
+ bsz = latents.shape[0]
+ current_step.value = global_step
+
+ with accelerator.accumulate(training_model):
+ accelerator.unwrap_model(network).on_step_start()
+
+ latents = latents * vae_module.SCALING_FACTOR
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+
+ # calculate model input and timesteps
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
+ )
+
+ weighting = compute_loss_weighting_for_sd3(
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
+ )
+
+ # ensure guidance_scale in args is float
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
+
+ # ensure the hidden state will require grad
+ if args.gradient_checkpointing:
+ noisy_model_input.requires_grad_(True)
+ guidance_vec.requires_grad_(True)
+
+ pos_emb_shape = latents.shape[1:]
+ if pos_emb_shape not in pos_embed_cache:
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:])
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
+ pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
+ else:
+ freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
+
+ # call DiT
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
+ # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
+ # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
+ # llm_mask = llm_mask.to(device=accelerator.device)
+ # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
+ with accelerator.autocast():
+ model_pred = transformer(
+ noisy_model_input,
+ timesteps,
+ text_states=llm_embeds,
+ text_mask=llm_mask,
+ text_states_2=clip_embeds,
+ freqs_cos=freqs_cos,
+ freqs_sin=freqs_sin,
+ guidance=guidance_vec,
+ return_dict=False,
+ )
+
+ # flow matching loss
+ target = noise - latents
+
+ loss = torch.nn.functional.mse_loss(model_pred.to(network_dtype), target, reduction="none")
+
+ if weighting is not None:
+ loss = loss * weighting
+ # loss = loss.mean([1, 2, 3])
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
+
+ loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
+ state = accelerate.PartialState()
+ if state.distributed_type != accelerate.DistributedType.NO:
+ for param in network.parameters():
+ if param.grad is not None:
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
+
+ if args.max_grad_norm != 0.0:
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ if args.scale_weight_norms:
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
+ args.scale_weight_norms, accelerator.device
+ )
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+ else:
+ keys_scaled, mean_norm, maximum_norm = None, None, None
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ optimizer_eval_fn()
+ self.sample_images(
+ accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
+ )
+
+ # 指定ステップごとにモデルを保存
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
+
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
+ if remove_step_no is not None:
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
+ remove_model(remove_ckpt_name)
+ optimizer_train_fn()
+
+ current_loss = loss.detach().item()
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+ avr_loss: float = loss_recorder.moving_average
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if args.scale_weight_norms:
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
+
+ if len(accelerator.trackers) > 0:
+ logs = self.generate_step_logs(
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
+ )
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if len(accelerator.trackers) > 0:
+ logs = {"loss/epoch": loss_recorder.moving_average}
+ accelerator.log(logs, step=epoch + 1)
+
+ accelerator.wait_for_everyone()
+
+ # 指定エポックごとにモデルを保存
+ optimizer_eval_fn()
+ if args.save_every_n_epochs is not None:
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
+ if is_main_process and saving:
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
+
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
+ if remove_epoch_no is not None:
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
+ remove_model(remove_ckpt_name)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
+
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
+ optimizer_train_fn()
+
+ # end of epoch
+
+ # metadata["ss_epoch"] = str(num_train_epochs)
+ metadata["ss_training_finished_at"] = str(time.time())
+
+ if is_main_process:
+ network = accelerator.unwrap_model(network)
+
+ accelerator.end_training()
+ optimizer_eval_fn()
+
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
+ train_utils.save_state_on_train_end(args, accelerator)
+
+ if is_main_process:
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
+
+ logger.info("model saved.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ def int_or_float(value):
+ if value.endswith("%"):
+ try:
+ return float(value[:-1]) / 100.0
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+ try:
+ float_value = float(value)
+ if float_value >= 1 and float_value.is_integer():
+ return int(value)
+ return float(value)
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
+ parser = argparse.ArgumentParser()
+
+ # general settings
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ default=None,
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
+ )
+ parser.add_argument(
+ "--dataset_config",
+ type=pathlib.Path,
+ default=None,
+ required=True,
+ help="config file for dataset / データセットの設定ファイル",
+ )
+
+ # training settings
+ parser.add_argument(
+ "--sdpa",
+ action="store_true",
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
+ )
+ parser.add_argument(
+ "--flash_attn",
+ action="store_true",
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
+ )
+ parser.add_argument(
+ "--sage_attn",
+ action="store_true",
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
+ )
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+ parser.add_argument(
+ "--max_train_epochs",
+ type=int,
+ default=None,
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
+ )
+ parser.add_argument(
+ "--max_data_loader_n_workers",
+ type=int,
+ default=8,
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
+ )
+ parser.add_argument(
+ "--persistent_data_loader_workers",
+ action="store_true",
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
+ parser.add_argument(
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help="use mixed precision / 混合精度を使う場合、その精度",
+ )
+
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default=None,
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
+ )
+ parser.add_argument(
+ "--log_with",
+ type=str,
+ default=None,
+ choices=["tensorboard", "wandb", "all"],
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
+ )
+ parser.add_argument(
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
+ )
+ parser.add_argument(
+ "--log_tracker_name",
+ type=str,
+ default=None,
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
+ )
+ parser.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default=None,
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
+ )
+ parser.add_argument(
+ "--log_tracker_config",
+ type=str,
+ default=None,
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
+ )
+ parser.add_argument(
+ "--wandb_api_key",
+ type=str,
+ default=None,
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
+ )
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
+
+ parser.add_argument(
+ "--ddp_timeout",
+ type=int,
+ default=None,
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
+ )
+ parser.add_argument(
+ "--ddp_gradient_as_bucket_view",
+ action="store_true",
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
+ )
+ parser.add_argument(
+ "--ddp_static_graph",
+ action="store_true",
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
+ )
+
+ parser.add_argument(
+ "--sample_every_n_steps",
+ type=int,
+ default=None,
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
+ )
+ parser.add_argument(
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
+ )
+ parser.add_argument(
+ "--sample_every_n_epochs",
+ type=int,
+ default=None,
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
+ )
+ parser.add_argument(
+ "--sample_prompts",
+ type=str,
+ default=None,
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
+ )
+
+ # optimizer and lr scheduler settings
+ parser.add_argument(
+ "--optimizer_type",
+ type=str,
+ default="",
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
+ )
+ parser.add_argument(
+ "--optimizer_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
+ )
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+ parser.add_argument(
+ "--max_grad_norm",
+ default=1.0,
+ type=float,
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
+ )
+
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_decay_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_scheduler_num_cycles",
+ type=int,
+ default=1,
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
+ )
+ parser.add_argument(
+ "--lr_scheduler_power",
+ type=float,
+ default=1,
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
+ )
+ parser.add_argument(
+ "--lr_scheduler_timescale",
+ type=int,
+ default=None,
+ help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
+ )
+ parser.add_argument(
+ "--lr_scheduler_min_lr_ratio",
+ type=float,
+ default=None,
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
+ )
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
+ parser.add_argument(
+ "--lr_scheduler_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
+ )
+
+ # model settings
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+ parser.add_argument(
+ "--vae_tiling",
+ action="store_true",
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
+ parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
+ # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
+ # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
+
+ parser.add_argument(
+ "--blocks_to_swap",
+ type=int,
+ default=None,
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
+ )
+ parser.add_argument(
+ "--img_in_txt_in_offloading",
+ action="store_true",
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
+ )
+
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.")
+ parser.add_argument(
+ "--timestep_sampling",
+ choices=["sigma", "uniform", "sigmoid", "shift"],
+ default="sigma",
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
+ )
+ parser.add_argument(
+ "--discrete_flow_shift",
+ type=float,
+ default=1.0,
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
+ )
+ parser.add_argument(
+ "--sigmoid_scale",
+ type=float,
+ default=1.0,
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
+ help="weighting scheme for timestep distribution. Default is none"
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
+ )
+ parser.add_argument(
+ "--logit_mean",
+ type=float,
+ default=0.0,
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
+ )
+ parser.add_argument(
+ "--logit_std",
+ type=float,
+ default=1.0,
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
+ )
+ parser.add_argument(
+ "--min_timestep",
+ type=int,
+ default=None,
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
+ )
+ parser.add_argument(
+ "--max_timestep",
+ type=int,
+ default=None,
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
+ )
+
+ parser.add_argument(
+ "--show_timesteps",
+ type=str,
+ default=None,
+ choices=["image", "console"],
+ help="show timesteps in image or console, and return to console / タイムステップを画像またはコンソールに表示し、コンソールに戻る",
+ )
+
+ # network settings
+ parser.add_argument(
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
+ )
+ parser.add_argument(
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
+ )
+ parser.add_argument(
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
+ )
+ parser.add_argument(
+ "--network_dim",
+ type=int,
+ default=None,
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
+ )
+ parser.add_argument(
+ "--network_alpha",
+ type=float,
+ default=1,
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
+ )
+ parser.add_argument(
+ "--network_dropout",
+ type=float,
+ default=None,
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
+ )
+ parser.add_argument(
+ "--network_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
+ )
+ parser.add_argument(
+ "--training_comment",
+ type=str,
+ default=None,
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
+ )
+ parser.add_argument(
+ "--dim_from_weights",
+ action="store_true",
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
+ )
+ parser.add_argument(
+ "--scale_weight_norms",
+ type=float,
+ default=None,
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
+ )
+ parser.add_argument(
+ "--base_weights",
+ type=str,
+ default=None,
+ nargs="*",
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
+ )
+ parser.add_argument(
+ "--base_weights_multiplier",
+ type=float,
+ default=None,
+ nargs="*",
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
+ )
+
+ # save and load settings
+ parser.add_argument(
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default=None,
+ required=True,
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
+ )
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
+
+ parser.add_argument(
+ "--save_every_n_epochs",
+ type=int,
+ default=None,
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
+ )
+ parser.add_argument(
+ "--save_every_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs",
+ type=int,
+ default=None,
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs_state",
+ type=int,
+ default=None,
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps_state",
+ type=int,
+ default=None,
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
+ )
+ parser.add_argument(
+ "--save_state",
+ action="store_true",
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
+ )
+ parser.add_argument(
+ "--save_state_on_train_end",
+ action="store_true",
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
+ )
+
+ # SAI Model spec
+ parser.add_argument(
+ "--metadata_title",
+ type=str,
+ default=None,
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
+ )
+ parser.add_argument(
+ "--metadata_author",
+ type=str,
+ default=None,
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
+ )
+ parser.add_argument(
+ "--metadata_description",
+ type=str,
+ default=None,
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
+ )
+ parser.add_argument(
+ "--metadata_license",
+ type=str,
+ default=None,
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
+ )
+ parser.add_argument(
+ "--metadata_tags",
+ type=str,
+ default=None,
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
+ )
+
+ # huggingface settings
+ parser.add_argument(
+ "--huggingface_repo_id",
+ type=str,
+ default=None,
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
+ )
+ parser.add_argument(
+ "--huggingface_repo_type",
+ type=str,
+ default=None,
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
+ )
+ parser.add_argument(
+ "--huggingface_path_in_repo",
+ type=str,
+ default=None,
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
+ )
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
+ parser.add_argument(
+ "--huggingface_repo_visibility",
+ type=str,
+ default=None,
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
+ )
+ parser.add_argument(
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
+ )
+ parser.add_argument(
+ "--resume_from_huggingface",
+ action="store_true",
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
+ )
+ parser.add_argument(
+ "--async_upload",
+ action="store_true",
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
+ )
+
+ return parser
+
+
+def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
+ if not args.config_file:
+ return args
+
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
+
+ if not os.path.exists(config_path):
+ logger.info(f"{config_path} not found.")
+ exit(1)
+
+ logger.info(f"Loading settings from {config_path}...")
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_dict = toml.load(f)
+
+ # combine all sections into one
+ ignore_nesting_dict = {}
+ for section_name, section_dict in config_dict.items():
+ # if value is not dict, save key and value as is
+ if not isinstance(section_dict, dict):
+ ignore_nesting_dict[section_name] = section_dict
+ continue
+
+ # if value is dict, save all key and value into one dict
+ for key, value in section_dict.items():
+ ignore_nesting_dict[key] = value
+
+ config_args = argparse.Namespace(**ignore_nesting_dict)
+ args = parser.parse_args(namespace=config_args)
+ args.config_file = os.path.splitext(args.config_file)[0]
+ logger.info(args.config_file)
+
+ return args
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+
+ args = parser.parse_args()
+ args = read_config_from_file(args, parser)
+
+ trainer = NetworkTrainer()
+ trainer.train(args)
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/custom_offloading_utils.py b/modules/custom_offloading_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f9db6543731c1b914484735830e209df27b1738
--- /dev/null
+++ b/modules/custom_offloading_utils.py
@@ -0,0 +1,262 @@
+from concurrent.futures import ThreadPoolExecutor
+import gc
+import time
+from typing import Optional
+import torch
+import torch.nn as nn
+
+
+def clean_memory_on_device(device: torch.device):
+ r"""
+ Clean memory on the specified device, will be called from training scripts.
+ """
+ gc.collect()
+
+ # device may "cuda" or "cuda:0", so we need to check the type of device
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ if device.type == "xpu":
+ torch.xpu.empty_cache()
+ if device.type == "mps":
+ torch.mps.empty_cache()
+
+
+def synchronize_device(device: torch.device):
+ if device.type == "cuda":
+ torch.cuda.synchronize()
+ elif device.type == "xpu":
+ torch.xpu.synchronize()
+ elif device.type == "mps":
+ torch.mps.synchronize()
+
+
+def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+ weight_swap_jobs = []
+
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+ else:
+ if module_to_cuda.weight.data.device.type != device.type:
+ # print(
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
+ # )
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
+
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
+
+ stream = torch.cuda.Stream()
+ with torch.cuda.stream(stream):
+ # cuda to cpu
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.record_stream(stream)
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+
+ stream.synchronize()
+
+ # cpu to cuda
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+ module_to_cuda.weight.data = cuda_data_view
+
+ stream.synchronize()
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
+
+
+def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+ """
+ not tested
+ """
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+ weight_swap_jobs = []
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+ # device to cpu
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+
+ synchronize_device()
+
+ # cpu to device
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+ module_to_cuda.weight.data = cuda_data_view
+
+ synchronize_device()
+
+
+def weighs_to_device(layer: nn.Module, device: torch.device):
+ for module in layer.modules():
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
+
+
+class Offloader:
+ """
+ common offloading class
+ """
+
+ def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
+ self.block_type = block_type
+ self.num_blocks = num_blocks
+ self.blocks_to_swap = blocks_to_swap
+ self.device = device
+ self.debug = debug
+
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
+ self.futures = {}
+ self.cuda_available = device.type == "cuda"
+
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
+ if self.cuda_available:
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
+ else:
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
+
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+ if self.debug:
+ start_time = time.perf_counter()
+ print(
+ f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
+ )
+
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
+
+ if self.debug:
+ print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
+ return bidx_to_cpu, bidx_to_cuda # , event
+
+ block_to_cpu = blocks[block_idx_to_cpu]
+ block_to_cuda = blocks[block_idx_to_cuda]
+
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
+ )
+
+ def _wait_blocks_move(self, block_idx):
+ if block_idx not in self.futures:
+ return
+
+ if self.debug:
+ print(f"[{self.block_type}] Wait for block {block_idx}")
+ start_time = time.perf_counter()
+
+ future = self.futures.pop(block_idx)
+ _, bidx_to_cuda = future.result()
+
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
+
+ if self.debug:
+ print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
+
+
+class ModelOffloader(Offloader):
+ """
+ supports forward offloading
+ """
+
+ def __init__(
+ self,
+ block_type: str,
+ blocks: list[nn.Module],
+ num_blocks: int,
+ blocks_to_swap: int,
+ supports_backward: bool,
+ device: torch.device,
+ debug: bool = False,
+ ):
+ super().__init__(block_type, num_blocks, blocks_to_swap, device, debug)
+
+ self.supports_backward = supports_backward
+
+ if self.supports_backward:
+ # register backward hooks
+ self.remove_handles = []
+ for i, block in enumerate(blocks):
+ hook = self.create_backward_hook(blocks, i)
+ if hook is not None:
+ handle = block.register_full_backward_hook(hook)
+ self.remove_handles.append(handle)
+
+ def __del__(self):
+ if self.supports_backward:
+ for handle in self.remove_handles:
+ handle.remove()
+
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
+ # -1 for 0-based index
+ num_blocks_propagated = self.num_blocks - block_index - 1
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
+
+ if not swapping and not waiting:
+ return None
+
+ # create hook
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
+ block_idx_to_wait = block_index - 1
+
+ def backward_hook(module, grad_input, grad_output):
+ if self.debug:
+ print(f"Backward hook for block {block_index}")
+
+ if swapping:
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
+ if waiting:
+ self._wait_blocks_move(block_idx_to_wait)
+ return None
+
+ return backward_hook
+
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+
+ if self.debug:
+ print(f"[{self.block_type}] Prepare block devices before forward")
+
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
+ b.to(self.device)
+ weighs_to_device(b, self.device) # make sure weights are on device
+
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
+ b.to(self.device) # move block to device first
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
+
+ synchronize_device(self.device)
+ clean_memory_on_device(self.device)
+
+ def wait_for_block(self, block_idx: int):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self._wait_blocks_move(block_idx)
+
+ def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int):
+ # check if blocks_to_swap is enabled
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+
+ # if supports_backward, we swap blocks more than blocks_to_swap in backward pass
+ if self.supports_backward and block_idx >= self.blocks_to_swap:
+ return
+
+ block_idx_to_cpu = block_idx
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
+ block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
diff --git a/modules/scheduling_flow_match_discrete.py b/modules/scheduling_flow_match_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..c507ec4eb050463188e250c20aec8d1fde2c4a5d
--- /dev/null
+++ b/modules/scheduling_flow_match_discrete.py
@@ -0,0 +1,257 @@
+# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule.
+ reverse (`bool`, defaults to `True`):
+ Whether to reverse the timestep schedule.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ reverse: bool = True,
+ solver: str = "euler",
+ n_tokens: Optional[int] = None,
+ ):
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
+
+ if not reverse:
+ sigmas = sigmas.flip(0)
+
+ self.sigmas = sigmas
+ # the value fed to model
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.supported_solver = ["euler"]
+ if solver not in self.supported_solver:
+ raise ValueError(
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
+ )
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ n_tokens: int = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
+ sigmas = self.sd3_time_shift(sigmas)
+
+ if not self.config.reverse:
+ sigmas = 1 - sigmas
+
+ self.sigmas = sigmas
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
+ dtype=torch.float32, device=device
+ )
+
+ # Reset step index
+ self._step_index = None
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def scale_model_input(
+ self, sample: torch.Tensor, timestep: Optional[int] = None
+ ) -> torch.Tensor:
+ return sample
+
+ def sd3_time_shift(self, t: torch.Tensor):
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
+
+ if self.config.solver == "euler":
+ prev_sample = sample + model_output.to(torch.float32) * dt
+ else:
+ raise ValueError(
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
+ )
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/modules/unet_causal_3d_blocks.py b/modules/unet_causal_3d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d544170ece6a370cdacfe9e31367b884c2e516
--- /dev/null
+++ b/modules/unet_causal_3d_blocks.py
@@ -0,0 +1,818 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+
+from diffusers.utils import logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import SpatialNorm
+from diffusers.models.attention_processor import Attention
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.normalization import RMSNorm
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
+ seq_len = n_frame * n_hw
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+ for i in range(seq_len):
+ i_frame = i // n_hw
+ mask[i, : (i_frame + 1) * n_hw] = 0
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+
+class CausalConv3d(nn.Module):
+ """
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
+ This maintains temporal causality in video generation tasks.
+ """
+
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ pad_mode="replicate",
+ chunk_size=0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
+ self.time_causal_padding = padding
+ self.chunk_size = chunk_size
+
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def original_forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+ def forward(self, x):
+ if self.chunk_size == 0:
+ return self.original_forward(x)
+
+ # if not large, call original forward
+ if x.shape[4] < self.chunk_size * 1.5:
+ return self.original_forward(x)
+
+ # # debug: verify the original forward is the same as chunked forward
+ # orig_forwarded_value = None
+ # if x.shape[4] < self.chunk_size * 4:
+ # orig_forwarded_value = self.original_forward(x)
+
+ # get the kernel size
+ kernel_size = self.conv.kernel_size[0] # assume cubic kernel
+ assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported"
+ padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1
+
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+
+ B, C, D, H, W = orig_shape = x.shape
+ chunk_size = self.chunk_size
+ chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride
+ # print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}")
+
+ # calculate the indices for chunking with overlap and padding by kernel size and stride
+ indices = []
+ i = 0
+ while i < W - padding_size:
+ start_idx = i - padding_size
+ end_idx = min(i + chunk_size + padding_size, W)
+ if i == 0:
+ start_idx = 0
+ end_idx += padding_size # to make sure the first chunk is divisible by stride
+ if W - end_idx < chunk_size // 2: # small chunk at the end
+ end_idx = W
+ indices.append((start_idx, end_idx))
+ i = end_idx - padding_size
+ # print(f"chunked forward: {x.shape}, chunked indices: {indices}")
+
+ chunks = []
+ for start_idx, end_idx in indices:
+ chunk = x[:, :, :, :, start_idx:end_idx]
+ chunk_output = self.conv(chunk)
+ # print(chunk.shape, chunk_output.shape)
+ chunks.append(chunk_output)
+
+ # concatenate the chunks
+ x = torch.cat(chunks, dim=4)
+
+ assert (
+ x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+ assert (
+ x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+ assert (
+ x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+
+ # # debug: verify the original forward is the same as chunked forward
+ # if orig_forwarded_value is not None:
+ # assert torch.allclose(
+ # orig_forwarded_value, x, rtol=1e-4, atol=1e-2
+ # ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}"
+
+ return x
+
+
+class UpsampleCausal3D(nn.Module):
+ """
+ A 3D upsampling layer with an optional convolution.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = False,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ kernel_size: Optional[int] = None,
+ padding=1,
+ norm_type=None,
+ eps=None,
+ elementwise_affine=None,
+ bias=True,
+ interpolate=True,
+ upsample_factor=(2, 2, 2),
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+ self.interpolate = interpolate
+ self.upsample_factor = upsample_factor
+
+ if norm_type == "ln_norm":
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f"unknown norm_type: {norm_type}")
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ if kernel_size is None:
+ kernel_size = 3
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ output_size: Optional[int] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ assert hidden_states.shape[1] == self.channels
+
+ if self.norm is not None:
+ raise NotImplementedError
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if self.interpolate:
+ B, C, T, H, W = hidden_states.shape
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
+ if output_size is None:
+ if T > 1:
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
+
+ first_h = first_h.squeeze(2)
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
+ first_h = first_h.unsqueeze(2)
+ else:
+ raise NotImplementedError
+
+ if T > 1:
+ hidden_states = torch.cat((first_h, other_h), dim=2)
+ else:
+ hidden_states = first_h
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class DownsampleCausal3D(nn.Module):
+ """
+ A 3D downsampling layer with an optional convolution.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ name: str = "conv",
+ kernel_size=3,
+ norm_type=None,
+ eps=None,
+ elementwise_affine=None,
+ bias=True,
+ stride=2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = stride
+ self.name = name
+
+ if norm_type == "ln_norm":
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f"unknown norm_type: {norm_type}")
+
+ if use_conv:
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ assert hidden_states.shape[1] == self.channels
+
+ if self.norm is not None:
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ assert hidden_states.shape[1] == self.channels
+
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlockCausal3D(nn.Module):
+ r"""
+ A Resnet block.
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ conv_shortcut: bool = False,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ groups: int = 32,
+ groups_out: Optional[int] = None,
+ pre_norm: bool = True,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ skip_time_act: bool = False,
+ # default, scale_shift, ada_group, spatial
+ time_embedding_norm: str = "default",
+ kernel: Optional[torch.FloatTensor] = None,
+ output_scale_factor: float = 1.0,
+ use_in_shortcut: Optional[bool] = None,
+ up: bool = False,
+ down: bool = False,
+ conv_shortcut_bias: bool = True,
+ conv_3d_out_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+ self.time_embedding_norm = time_embedding_norm
+ self.skip_time_act = skip_time_act
+
+ linear_cls = nn.Linear
+
+ if groups_out is None:
+ groups_out = groups
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
+ else:
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
+ elif self.time_embedding_norm == "scale_shift":
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ self.time_emb_proj = None
+ else:
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
+ else:
+ self.time_emb_proj = None
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
+ else:
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
+ self.dropout = torch.nn.Dropout(dropout)
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = self.downsample = None
+ if self.up:
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
+ elif self.down:
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
+
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = CausalConv3d(
+ in_channels,
+ conv_3d_out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=conv_shortcut_bias,
+ )
+
+ def forward(
+ self,
+ input_tensor: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ hidden_states = input_tensor
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm1(hidden_states, temb)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor, scale=scale)
+ hidden_states = self.upsample(hidden_states, scale=scale)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor, scale=scale)
+ hidden_states = self.downsample(hidden_states, scale=scale)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if self.time_emb_proj is not None:
+ if not self.skip_time_act:
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb, scale)[:, :, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm2(hidden_states, temb)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+def get_down_block3d(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ downsample_stride: int,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ downsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownEncoderBlockCausal3D":
+ return DownEncoderBlockCausal3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ downsample_stride=downsample_stride,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block3d(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ upsample_scale_factor: Tuple,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resolution_idx: Optional[int] = None,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ upsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+) -> nn.Module:
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpDecoderBlockCausal3D":
+ return UpDecoderBlockCausal3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ upsample_scale_factor=upsample_scale_factor,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlockCausal3D(nn.Module):
+ """
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ attn_groups: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ if attn_groups is None:
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=attn_groups,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ B, C, T, H, W = hidden_states.shape
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+ attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class DownEncoderBlockCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_stride: int = 2,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ DownsampleCausal3D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ stride=downsample_stride,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale)
+
+ return hidden_states
+
+
+class UpDecoderBlockCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ upsample_scale_factor=(2, 2, 2),
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ UpsampleCausal3D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ upsample_factor=upsample_scale_factor,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
diff --git a/networks/__init__.py b/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/networks/lora.py b/networks/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb8912703dc7cfbbe0180f3fa5b794d6d48c334
--- /dev/null
+++ b/networks/lora.py
@@ -0,0 +1,828 @@
+# LoRA network module: currently conv2d is not fully supported
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+import math
+import os
+from typing import Dict, List, Optional, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import torch.nn as nn
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]
+
+
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ split_dims: Optional[List[int]] = None,
+ ):
+ """
+ if alpha == 0 or None, alpha is rank (no scaling).
+
+ split_dims is used to mimic the split qkv of multi-head attention.
+ """
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ self.lora_dim = lora_dim
+ self.split_dims = split_dims
+
+ if split_dims is None:
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+ else:
+ # conv2d not supported
+ assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
+ assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
+ # print(f"split_dims: {split_dims}")
+ self.lora_down = torch.nn.ModuleList(
+ [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
+ )
+ self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
+ for lora_down in self.lora_down:
+ torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+ for lora_up in self.lora_up:
+ torch.nn.init.zeros_(lora_up.weight)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha)) # for save/load
+
+ # same as microsoft's
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded + lx * self.multiplier * scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
+ for i in range(len(lxs)):
+ if len(lx.size()) == 3:
+ masks[i] = masks[i].unsqueeze(1)
+ elif len(lx.size()) == 4:
+ masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
+ lxs[i] = lxs[i] * masks[i]
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+
+ return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
+
+
+class LoRAInfModule(LoRAModule):
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+ self.org_module_ref = [org_module] # for reference
+ self.enabled = True
+ self.network: LoRANetwork = None
+
+ def set_network(self, network):
+ self.network = network
+
+ # merge weight to org_module
+ def merge_to(self, sd, dtype, device):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(device, dtype=torch.float) # for calculation
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ if self.split_dims is None:
+ # get up/down weight
+ down_weight = sd["lora_down.weight"].to(device, dtype=torch.float)
+ up_weight = sd["lora_up.weight"].to(device, dtype=torch.float)
+
+ # merge weight
+ if len(weight.size()) == 2:
+ # linear
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ weight
+ + self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+ weight = weight + self.multiplier * conved * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(org_device, dtype=dtype)
+ self.org_module.load_state_dict(org_sd)
+ else:
+ # split_dims
+ total_dims = sum(self.split_dims)
+ for i in range(len(self.split_dims)):
+ # get up/down weight
+ down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
+ up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
+
+ # pad up_weight -> (total_dims, rank)
+ padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
+ padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
+
+ # merge weight
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+
+ # return weight for merge
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ # get up/down weight from module
+ up_weight = self.lora_up.weight.to(torch.float)
+ down_weight = self.lora_down.weight.to(torch.float)
+
+ # pre-calculated weight
+ if len(down_weight.size()) == 2:
+ # linear
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ weight = self.multiplier * conved * self.scale
+
+ return weight
+
+ def default_forward(self, x):
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+ lx = self.lora_up(lx)
+ return self.org_forward(x) + lx * self.multiplier * self.scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+ return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
+def create_network_hunyuan_video(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ return create_network(
+ HUNYUAN_TARGET_REPLACE_MODULES,
+ "lora_unet",
+ multiplier,
+ network_dim,
+ network_alpha,
+ vae,
+ text_encoders,
+ unet,
+ neuron_dropout=neuron_dropout,
+ **kwargs,
+ )
+
+
+def create_network(
+ target_replace_modules: List[str],
+ prefix: str,
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ if network_dim is None:
+ network_dim = 4 # default
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ # extract dim/alpha for conv2d, and block dim
+ conv_dim = kwargs.get("conv_dim", None)
+ conv_alpha = kwargs.get("conv_alpha", None)
+ if conv_dim is not None:
+ conv_dim = int(conv_dim)
+ if conv_alpha is None:
+ conv_alpha = 1.0
+ else:
+ conv_alpha = float(conv_alpha)
+
+ # TODO generic rank/dim setting with regular expression
+
+ # rank/module dropout
+ rank_dropout = kwargs.get("rank_dropout", None)
+ if rank_dropout is not None:
+ rank_dropout = float(rank_dropout)
+ module_dropout = kwargs.get("module_dropout", None)
+ if module_dropout is not None:
+ module_dropout = float(module_dropout)
+
+ # verbose
+ verbose = kwargs.get("verbose", False)
+ if verbose is not None:
+ verbose = True if verbose == "True" else False
+
+ # too many arguments ( ^ω^)・・・
+ network = LoRANetwork(
+ target_replace_modules,
+ prefix,
+ text_encoders,
+ unet,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ conv_lora_dim=conv_dim,
+ conv_alpha=conv_alpha,
+ verbose=verbose,
+ )
+
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+ # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+ # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+ # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+ # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+ if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+ return network
+
+
+class LoRANetwork(torch.nn.Module):
+ # only supports U-Net (DiT), Text Encoders are not supported
+
+ def __init__(
+ self,
+ target_replace_modules: List[str],
+ prefix: str,
+ text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
+ unet: nn.Module,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ module_class: Type[object] = LoRAModule,
+ modules_dim: Optional[Dict[str, int]] = None,
+ modules_alpha: Optional[Dict[str, int]] = None,
+ verbose: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+ self.multiplier = multiplier
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.conv_lora_dim = conv_lora_dim
+ self.conv_alpha = conv_alpha
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.target_replace_modules = target_replace_modules
+ self.prefix = prefix
+
+ self.loraplus_lr_ratio = None
+ # self.loraplus_unet_lr_ratio = None
+ # self.loraplus_text_encoder_lr_ratio = None
+
+ if modules_dim is not None:
+ logger.info(f"create LoRA network from weights")
+ else:
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ logger.info(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+ )
+ # if self.conv_lora_dim is not None:
+ # logger.info(
+ # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+ # )
+ # if train_t5xxl:
+ # logger.info(f"train T5XXL as well")
+
+ # create module instances
+ def create_modules(
+ is_unet: bool,
+ pfx: str,
+ root_module: torch.nn.Module,
+ target_replace_mods: List[str],
+ filter: Optional[str] = None,
+ default_dim: Optional[int] = None,
+ ) -> List[LoRAModule]:
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if target_replace_mods is None or module.__class__.__name__ in target_replace_mods:
+ if target_replace_mods is None: # dirty hack for all modules
+ module = root_module # search all modules
+
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear"
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+ if is_linear or is_conv2d:
+ original_name = (name + "." if name else "") + child_name
+ lora_name = f"{pfx}.{original_name}".replace(".", "_")
+
+ if filter is not None and not filter in lora_name:
+ continue
+
+ dim = None
+ alpha = None
+
+ if modules_dim is not None:
+ # モジュール指定あり
+ if lora_name in modules_dim:
+ dim = modules_dim[lora_name]
+ alpha = modules_alpha[lora_name]
+ else:
+ # 通常、すべて対象とする
+ if is_linear or is_conv2d_1x1:
+ dim = default_dim if default_dim is not None else self.lora_dim
+ alpha = self.alpha
+ elif self.conv_lora_dim is not None:
+ dim = self.conv_lora_dim
+ alpha = self.conv_alpha
+
+ if dim is None or dim == 0:
+ # skipした情報を出力
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ )
+ loras.append(lora)
+
+ if target_replace_mods is None:
+ break # all modules are searched
+ return loras, skipped
+
+ # # create LoRA for text encoder
+ # # it is redundant to create LoRA modules even if they are not used
+
+ self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
+ # skipped_te = []
+ # for i, text_encoder in enumerate(text_encoders):
+ # index = i
+ # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
+ # break
+ # logger.info(f"create LoRA for Text Encoder {index+1}:")
+ # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+ # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
+ # self.text_encoder_loras.extend(text_encoder_loras)
+ # skipped_te += skipped
+
+ # create LoRA for U-Net
+ self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
+ self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules)
+
+ logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.")
+ if verbose:
+ for lora in self.unet_loras:
+ logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
+
+ skipped = skipped_un
+ if verbose and len(skipped) > 0:
+ logger.warning(
+ f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+ )
+ for name in skipped:
+ logger.info(f"\t{name}")
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def prepare_network(self, args):
+ """
+ called after the network is created
+ """
+ pass
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def set_enabled(self, is_enabled):
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.enabled = is_enabled
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def apply_to(
+ self,
+ text_encoders: Optional[nn.Module],
+ unet: Optional[nn.Module],
+ apply_text_encoder: bool = True,
+ apply_unet: bool = True,
+ ):
+ if apply_text_encoder:
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ # マージできるかどうかを返す
+ def is_mergeable(self):
+ return True
+
+ # TODO refactor to common function with apply_to
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
+ for lora in self.text_encoder_loras + self.unet_loras:
+ sd_for_lora = {}
+ for key in weights_sd.keys():
+ if key.startswith(lora.lora_name):
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+ if len(sd_for_lora) == 0:
+ logger.info(f"no weight for {lora.lora_name}")
+ continue
+ lora.merge_to(sd_for_lora, dtype, device)
+
+ logger.info(f"weights are merged")
+
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+ self.loraplus_lr_ratio = loraplus_lr_ratio
+
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}")
+ # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
+ def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs):
+ self.requires_grad_(True)
+
+ all_params = []
+ lr_descriptions = []
+
+ def assemble_params(loras, lr, loraplus_ratio):
+ param_groups = {"lora": {}, "plus": {}}
+ for lora in loras:
+ for name, param in lora.named_parameters():
+ if loraplus_ratio is not None and "lora_up" in name:
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+ params = []
+ descriptions = []
+ for key in param_groups.keys():
+ param_data = {"params": param_groups[key].values()}
+
+ if len(param_data["params"]) == 0:
+ continue
+
+ if lr is not None:
+ if key == "plus":
+ param_data["lr"] = lr * loraplus_ratio
+ else:
+ param_data["lr"] = lr
+
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+
+ params.append(param_data)
+ descriptions.append("plus" if key == "plus" else "")
+
+ return params, descriptions
+
+ if self.unet_loras:
+ params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio)
+ all_params.extend(params)
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
+
+ return all_params, lr_descriptions
+
+ def enable_gradient_checkpointing(self):
+ # not supported
+ pass
+
+ def prepare_grad_etc(self, unet):
+ self.requires_grad_(True)
+
+ def on_epoch_start(self, unet):
+ self.train()
+
+ def on_step_start(self):
+ pass
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+ from utils import model_utils
+
+ # Precalculate model hashes to save time on indexing
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+ def backup_weights(self):
+ # 重みのバックアップを行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not hasattr(org_module, "_lora_org_weight"):
+ sd = org_module.state_dict()
+ org_module._lora_org_weight = sd["weight"].detach().clone()
+ org_module._lora_restored = True
+
+ def restore_weights(self):
+ # 重みのリストアを行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not org_module._lora_restored:
+ sd = org_module.state_dict()
+ sd["weight"] = org_module._lora_org_weight
+ org_module.load_state_dict(sd)
+ org_module._lora_restored = True
+
+ def pre_calculation(self):
+ # 事前計算を行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ sd = org_module.state_dict()
+
+ org_weight = sd["weight"]
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+ sd["weight"] = org_weight + lora_weight
+ assert sd["weight"].shape == org_weight.shape
+ org_module.load_state_dict(sd)
+
+ org_module._lora_restored = False
+ lora.enabled = False
+
+ def apply_max_norm_regularization(self, max_norm_value, device):
+ downkeys = []
+ upkeys = []
+ alphakeys = []
+ norms = []
+ keys_scaled = 0
+
+ state_dict = self.state_dict()
+ for key in state_dict.keys():
+ if "lora_down" in key and "weight" in key:
+ downkeys.append(key)
+ upkeys.append(key.replace("lora_down", "lora_up"))
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+ for i in range(len(downkeys)):
+ down = state_dict[downkeys[i]].to(device)
+ up = state_dict[upkeys[i]].to(device)
+ alpha = state_dict[alphakeys[i]].to(device)
+ dim = down.shape[0]
+ scale = alpha / dim
+
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+ else:
+ updown = up @ down
+
+ updown *= scale
+
+ norm = updown.norm().clamp(min=max_norm_value / 2)
+ desired = torch.clamp(norm, max=max_norm_value)
+ ratio = desired.cpu() / norm.cpu()
+ sqrt_ratio = ratio**0.5
+ if ratio != 1:
+ keys_scaled += 1
+ state_dict[upkeys[i]] *= sqrt_ratio
+ state_dict[downkeys[i]] *= sqrt_ratio
+ scalednorm = updown.norm() * ratio
+ norms.append(scalednorm.item())
+
+ return keys_scaled, sum(norms) / len(norms), max(norms)
+
+
+def create_network_from_weights_hunyuan_video(
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> LoRANetwork:
+ return create_network_from_weights(
+ HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
+ )
+
+
+# Create network from weights for inference, weights are not loaded here (because can be merged)
+def create_network_from_weights(
+ target_replace_modules: List[str],
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> LoRANetwork:
+ # get dim/alpha mapping
+ modules_dim = {}
+ modules_alpha = {}
+ for key, value in weights_sd.items():
+ if "." not in key:
+ continue
+
+ lora_name = key.split(".")[0]
+ if "alpha" in key:
+ modules_alpha[lora_name] = value
+ elif "lora_down" in key:
+ dim = value.shape[0]
+ modules_dim[lora_name] = dim
+ # logger.info(lora_name, value.size(), dim)
+
+ module_class = LoRAInfModule if for_inference else LoRAModule
+
+ network = LoRANetwork(
+ target_replace_modules,
+ "lora_unet",
+ text_encoders,
+ unet,
+ multiplier=multiplier,
+ modules_dim=modules_dim,
+ modules_alpha=modules_alpha,
+ module_class=module_class,
+ )
+ return network
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8196990751772bec947a1ee70fa831dbac02aabe
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+accelerate==1.2.1
+av==14.0.1
+bitsandbytes==0.45.0
+diffusers==0.32.1
+einops==0.7.0
+huggingface-hub==0.26.5
+opencv-python==4.10.0.84
+pillow==10.2.0
+safetensors==0.4.5
+toml==0.10.2
+tqdm==4.67.1
+transformers==4.46.3
+voluptuous==0.15.2
+
+# optional dependencies
+# ascii-magic==2.3.0
+# matplotlib==3.10.0
+# tensorboard
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/huggingface_utils.py b/utils/huggingface_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc7bd7dbb2ef70e0b6244b9db686aae00f46408
--- /dev/null
+++ b/utils/huggingface_utils.py
@@ -0,0 +1,89 @@
+import threading
+from typing import Union, BinaryIO
+from huggingface_hub import HfApi
+from pathlib import Path
+import argparse
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def fire_in_thread(f, *args, **kwargs):
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
+
+
+def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
+ api = HfApi(
+ token=token,
+ )
+ try:
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+ return True
+ except:
+ return False
+
+
+def upload(
+ args: argparse.Namespace,
+ src: Union[str, Path, bytes, BinaryIO],
+ dest_suffix: str = "",
+ force_sync_upload: bool = False,
+):
+ repo_id = args.huggingface_repo_id
+ repo_type = args.huggingface_repo_type
+ token = args.huggingface_token
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
+ api = HfApi(token=token)
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
+ try:
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
+ except Exception as e: # RepositoryNotFoundError or something else
+ logger.error("===========================================")
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
+ logger.error("===========================================")
+
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
+
+ def uploader():
+ try:
+ if is_folder:
+ api.upload_folder(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ folder_path=src,
+ path_in_repo=path_in_repo,
+ )
+ else:
+ api.upload_file(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ path_or_fileobj=src,
+ path_in_repo=path_in_repo,
+ )
+ except Exception as e: # RuntimeError or something else
+ logger.error("===========================================")
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
+ logger.error("===========================================")
+
+ if args.async_upload and not force_sync_upload:
+ fire_in_thread(uploader)
+ else:
+ uploader()
+
+
+def list_dir(
+ repo_id: str,
+ subfolder: str,
+ repo_type: str,
+ revision: str = "main",
+ token: str = None,
+):
+ api = HfApi(
+ token=token,
+ )
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
+ return file_list
diff --git a/utils/model_utils.py b/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5beed8ec4e09f433ba2e84556a6c8f342a2903f5
--- /dev/null
+++ b/utils/model_utils.py
@@ -0,0 +1,151 @@
+import hashlib
+from io import BytesIO
+from typing import Optional
+
+import safetensors.torch
+import torch
+
+
+def model_hash(filename):
+ """Old model hash used by stable-diffusion-webui"""
+ try:
+ with open(filename, "rb") as file:
+ m = hashlib.sha256()
+
+ file.seek(0x100000)
+ m.update(file.read(0x10000))
+ return m.hexdigest()[0:8]
+ except FileNotFoundError:
+ return "NOFILE"
+ except IsADirectoryError: # Linux?
+ return "IsADirectory"
+ except PermissionError: # Windows
+ return "IsADirectory"
+
+
+def calculate_sha256(filename):
+ """New model hash used by stable-diffusion-webui"""
+ try:
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+ except FileNotFoundError:
+ return "NOFILE"
+ except IsADirectoryError: # Linux?
+ return "IsADirectory"
+ except PermissionError: # Windows
+ return "IsADirectory"
+
+
+def addnet_hash_legacy(b):
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+ m = hashlib.sha256()
+
+ b.seek(0x100000)
+ m.update(b.read(0x10000))
+ return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
+ save time on indexing the model later."""
+
+ # Because writing user metadata to the file can change the result of
+ # sd_models.model_hash(), only retain the training metadata for purposes of
+ # calculating the hash, as they are meant to be immutable
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+ bytes = safetensors.torch.save(tensors, metadata)
+ b = BytesIO(bytes)
+
+ model_hash = addnet_hash_safetensors(b)
+ legacy_hash = addnet_hash_legacy(b)
+ return model_hash, legacy_hash
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+ # get name of the dtype
+ dtype_name = str(dtype).split(".")[-1]
+ return dtype_name
+
+
+def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
+ """
+ Convert a string to a torch.dtype
+
+ Args:
+ s: string representation of the dtype
+ default_dtype: default dtype to return if s is None
+
+ Returns:
+ torch.dtype: the corresponding torch.dtype
+
+ Raises:
+ ValueError: if the dtype is not supported
+
+ Examples:
+ >>> str_to_dtype("float32")
+ torch.float32
+ >>> str_to_dtype("fp32")
+ torch.float32
+ >>> str_to_dtype("float16")
+ torch.float16
+ >>> str_to_dtype("fp16")
+ torch.float16
+ >>> str_to_dtype("bfloat16")
+ torch.bfloat16
+ >>> str_to_dtype("bf16")
+ torch.bfloat16
+ >>> str_to_dtype("fp8")
+ torch.float8_e4m3fn
+ >>> str_to_dtype("fp8_e4m3fn")
+ torch.float8_e4m3fn
+ >>> str_to_dtype("fp8_e4m3fnuz")
+ torch.float8_e4m3fnuz
+ >>> str_to_dtype("fp8_e5m2")
+ torch.float8_e5m2
+ >>> str_to_dtype("fp8_e5m2fnuz")
+ torch.float8_e5m2fnuz
+ """
+ if s is None:
+ return default_dtype
+ if s in ["bf16", "bfloat16"]:
+ return torch.bfloat16
+ elif s in ["fp16", "float16"]:
+ return torch.float16
+ elif s in ["fp32", "float32", "float"]:
+ return torch.float32
+ elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
+ return torch.float8_e4m3fn
+ elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
+ return torch.float8_e4m3fnuz
+ elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
+ return torch.float8_e5m2
+ elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
+ return torch.float8_e5m2fnuz
+ elif s in ["fp8", "float8"]:
+ return torch.float8_e4m3fn # default fp8
+ else:
+ raise ValueError(f"Unsupported dtype: {s}")
diff --git a/utils/safetensors_utils.py b/utils/safetensors_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b746bc0b7924513616d5dc0c6f9a8ad3f8c37bf9
--- /dev/null
+++ b/utils/safetensors_utils.py
@@ -0,0 +1,191 @@
+import torch
+import json
+import struct
+from typing import Dict, Any, Union, Optional
+
+from safetensors.torch import load_file
+
+
+def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
+ """
+ memory efficient save file
+ """
+
+ _TYPES = {
+ torch.float64: "F64",
+ torch.float32: "F32",
+ torch.float16: "F16",
+ torch.bfloat16: "BF16",
+ torch.int64: "I64",
+ torch.int32: "I32",
+ torch.int16: "I16",
+ torch.int8: "I8",
+ torch.uint8: "U8",
+ torch.bool: "BOOL",
+ getattr(torch, "float8_e5m2", None): "F8_E5M2",
+ getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
+ }
+ _ALIGN = 256
+
+ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
+ validated = {}
+ for key, value in metadata.items():
+ if not isinstance(key, str):
+ raise ValueError(f"Metadata key must be a string, got {type(key)}")
+ if not isinstance(value, str):
+ print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
+ validated[key] = str(value)
+ else:
+ validated[key] = value
+ return validated
+
+ # print(f"Using memory efficient save file: {filename}")
+
+ header = {}
+ offset = 0
+ if metadata:
+ header["__metadata__"] = validate_metadata(metadata)
+ for k, v in tensors.items():
+ if v.numel() == 0: # empty tensor
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
+ else:
+ size = v.numel() * v.element_size()
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
+ offset += size
+
+ hjson = json.dumps(header).encode("utf-8")
+ hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
+
+ with open(filename, "wb") as f:
+ f.write(struct.pack(" Dict[str, str]:
+ return self.header.get("__metadata__", {})
+
+ def get_tensor(self, key):
+ if key not in self.header:
+ raise KeyError(f"Tensor '{key}' not found in the file")
+
+ metadata = self.header[key]
+ offset_start, offset_end = metadata["data_offsets"]
+
+ if offset_start == offset_end:
+ tensor_bytes = None
+ else:
+ # adjust offset by header size
+ self.file.seek(self.header_size + 8 + offset_start)
+ tensor_bytes = self.file.read(offset_end - offset_start)
+
+ return self._deserialize_tensor(tensor_bytes, metadata)
+
+ def _read_header(self):
+ header_size = struct.unpack(" dict[str, torch.Tensor]:
+ if disable_mmap:
+ # return safetensors.torch.load(open(path, "rb").read())
+ # use experimental loader
+ # logger.info(f"Loading without mmap (experimental)")
+ state_dict = {}
+ with MemoryEfficientSafeOpen(path) as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
+ return state_dict
+ else:
+ try:
+ state_dict = load_file(path, device=device)
+ except:
+ state_dict = load_file(path) # prevent device invalid Error
+ if dtype is not None:
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].to(dtype=dtype)
+ return state_dict
diff --git a/utils/sai_model_spec.py b/utils/sai_model_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8d2939f1319b619b92225143a9413170be0712
--- /dev/null
+++ b/utils/sai_model_spec.py
@@ -0,0 +1,263 @@
+# based on https://github.com/Stability-AI/ModelSpec
+import datetime
+import hashlib
+from io import BytesIO
+import os
+from typing import List, Optional, Tuple, Union
+import safetensors
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+r"""
+# Metadata Example
+metadata = {
+ # === Must ===
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
+ "modelspec.implementation": "sgm",
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
+ # === Should ===
+ "modelspec.author": "Example Corp", # Your name or company name
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
+ # === Can ===
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
+}
+"""
+
+BASE_METADATA = {
+ # === Must ===
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
+ "modelspec.architecture": None,
+ "modelspec.implementation": None,
+ "modelspec.title": None,
+ "modelspec.resolution": None,
+ # === Should ===
+ "modelspec.description": None,
+ "modelspec.author": None,
+ "modelspec.date": None,
+ # === Can ===
+ "modelspec.license": None,
+ "modelspec.tags": None,
+ "modelspec.merged_from": None,
+ "modelspec.prediction_type": None,
+ "modelspec.timestep_range": None,
+ "modelspec.encoder_layer": None,
+}
+
+# 別に使うやつだけ定義
+MODELSPEC_TITLE = "modelspec.title"
+
+ARCH_HUNYUAN_VIDEO = "hunyuan-video"
+
+ADAPTER_LORA = "lora"
+
+IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"
+
+PRED_TYPE_EPSILON = "epsilon"
+# PRED_TYPE_V = "v"
+
+
+def load_bytes_in_safetensors(tensors):
+ bytes = safetensors.torch.save(tensors)
+ b = BytesIO(bytes)
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+
+ return b.read()
+
+
+def precalculate_safetensors_hashes(state_dict):
+ # calculate each tensor one by one to reduce memory usage
+ hash_sha256 = hashlib.sha256()
+ for tensor in state_dict.values():
+ single_tensor_sd = {"tensor": tensor}
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
+ hash_sha256.update(bytes_for_tensor)
+
+ return f"0x{hash_sha256.hexdigest()}"
+
+
+def update_hash_sha256(metadata: dict, state_dict: dict):
+ raise NotImplementedError
+
+
+def build_metadata(
+ state_dict: Optional[dict],
+ timestamp: float,
+ title: Optional[str] = None,
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
+ author: Optional[str] = None,
+ description: Optional[str] = None,
+ license: Optional[str] = None,
+ tags: Optional[str] = None,
+ merged_from: Optional[str] = None,
+ timesteps: Optional[Tuple[int, int]] = None,
+):
+ metadata = {}
+ metadata.update(BASE_METADATA)
+
+ # TODO implement if we can calculate hash without loading all tensors
+ # if state_dict is not None:
+ # hash = precalculate_safetensors_hashes(state_dict)
+ # metadata["modelspec.hash_sha256"] = hash
+
+ arch = ARCH_HUNYUAN_VIDEO
+ arch += f"/{ADAPTER_LORA}"
+ metadata["modelspec.architecture"] = arch
+
+ impl = IMPL_HUNYUAN_VIDEO
+ metadata["modelspec.implementation"] = impl
+
+ if title is None:
+ title = "LoRA"
+ title += f"@{timestamp}"
+ metadata[MODELSPEC_TITLE] = title
+
+ if author is not None:
+ metadata["modelspec.author"] = author
+ else:
+ del metadata["modelspec.author"]
+
+ if description is not None:
+ metadata["modelspec.description"] = description
+ else:
+ del metadata["modelspec.description"]
+
+ if merged_from is not None:
+ metadata["modelspec.merged_from"] = merged_from
+ else:
+ del metadata["modelspec.merged_from"]
+
+ if license is not None:
+ metadata["modelspec.license"] = license
+ else:
+ del metadata["modelspec.license"]
+
+ if tags is not None:
+ metadata["modelspec.tags"] = tags
+ else:
+ del metadata["modelspec.tags"]
+
+ # remove microsecond from time
+ int_ts = int(timestamp)
+
+ # time to iso-8601 compliant date
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
+ metadata["modelspec.date"] = date
+
+ if reso is not None:
+ # comma separated to tuple
+ if isinstance(reso, str):
+ reso = tuple(map(int, reso.split(",")))
+ if len(reso) == 1:
+ reso = (reso[0], reso[0])
+ else:
+ # resolution is defined in dataset, so use default
+ reso = (1280, 720)
+ if isinstance(reso, int):
+ reso = (reso, reso)
+
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
+
+ # metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
+ del metadata["modelspec.prediction_type"]
+
+ if timesteps is not None:
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
+ timesteps = (timesteps, timesteps)
+ if len(timesteps) == 1:
+ timesteps = (timesteps[0], timesteps[0])
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
+ else:
+ del metadata["modelspec.timestep_range"]
+
+ # if clip_skip is not None:
+ # metadata["modelspec.encoder_layer"] = f"{clip_skip}"
+ # else:
+ del metadata["modelspec.encoder_layer"]
+
+ # # assert all values are filled
+ # assert all([v is not None for v in metadata.values()]), metadata
+ if not all([v is not None for v in metadata.values()]):
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
+
+ return metadata
+
+
+# region utils
+
+
+def get_title(metadata: dict) -> Optional[str]:
+ return metadata.get(MODELSPEC_TITLE, None)
+
+
+def load_metadata_from_safetensors(model: str) -> dict:
+ if not model.endswith(".safetensors"):
+ return {}
+
+ with safetensors.safe_open(model, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ return metadata
+
+
+def build_merged_from(models: List[str]) -> str:
+ def get_title(model: str):
+ metadata = load_metadata_from_safetensors(model)
+ title = metadata.get(MODELSPEC_TITLE, None)
+ if title is None:
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
+ return title
+
+ titles = [get_title(model) for model in models]
+ return ", ".join(titles)
+
+
+# endregion
+
+
+r"""
+if __name__ == "__main__":
+ import argparse
+ import torch
+ from safetensors.torch import load_file
+ from library import train_util
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt", type=str, required=True)
+ args = parser.parse_args()
+
+ print(f"Loading {args.ckpt}")
+ state_dict = load_file(args.ckpt)
+
+ print(f"Calculating metadata")
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
+ print(metadata)
+ del state_dict
+
+ # by reference implementation
+ with open(args.ckpt, mode="rb") as file_data:
+ file_hash = hashlib.sha256()
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
+ content = (
+ file_data.read()
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
+ file_hash.update(content)
+ # ===== Update the hash for modelspec =====
+ by_ref = f"0x{file_hash.hexdigest()}"
+ print(by_ref)
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
+
+"""
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6af9ae69c5f4748406ae96577f373ae6df5da1
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,177 @@
+import argparse
+import logging
+import os
+import shutil
+
+import accelerate
+import torch
+
+from utils import huggingface_utils
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+# checkpointファイル名
+EPOCH_STATE_NAME = "{}-{:06d}-state"
+EPOCH_FILE_NAME = "{}-{:06d}"
+EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
+LAST_STATE_NAME = "{}-state"
+STEP_STATE_NAME = "{}-step{:08d}-state"
+STEP_FILE_NAME = "{}-step{:08d}"
+STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
+
+
+def get_sanitized_config_or_none(args: argparse.Namespace):
+ # if `--log_config` is enabled, return args for logging. if not, return None.
+ # when `--log_config is enabled, filter out sensitive values from args
+ # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe
+
+ if not args.log_config:
+ return None
+
+ sensitive_args = ["wandb_api_key", "huggingface_token"]
+ sensitive_path_args = [
+ "dit",
+ "vae",
+ "text_encoder1",
+ "text_encoder2",
+ "base_weights",
+ "network_weights",
+ "output_dir",
+ "logging_dir",
+ ]
+ filtered_args = {}
+ for k, v in vars(args).items():
+ # filter out sensitive values and convert to string if necessary
+ if k not in sensitive_args + sensitive_path_args:
+ # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
+ if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
+ filtered_args[k] = v
+ # accelerate does not support lists
+ elif isinstance(v, list):
+ filtered_args[k] = f"{v}"
+ # accelerate does not support objects
+ elif isinstance(v, object):
+ filtered_args[k] = f"{v}"
+
+ return filtered_args
+
+
+class LossRecorder:
+ def __init__(self):
+ self.loss_list: list[float] = []
+ self.loss_total: float = 0.0
+
+ def add(self, *, epoch: int, step: int, loss: float) -> None:
+ if epoch == 0:
+ self.loss_list.append(loss)
+ else:
+ while len(self.loss_list) <= step:
+ self.loss_list.append(0.0)
+ self.loss_total -= self.loss_list[step]
+ self.loss_list[step] = loss
+ self.loss_total += loss
+
+ @property
+ def moving_average(self) -> float:
+ return self.loss_total / len(self.loss_list)
+
+
+def get_epoch_ckpt_name(model_name, epoch_no: int):
+ return EPOCH_FILE_NAME.format(model_name, epoch_no) + ".safetensors"
+
+
+def get_step_ckpt_name(model_name, step_no: int):
+ return STEP_FILE_NAME.format(model_name, step_no) + ".safetensors"
+
+
+def get_last_ckpt_name(model_name):
+ return model_name + ".safetensors"
+
+
+def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
+ if args.save_last_n_epochs is None:
+ return None
+
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
+ if remove_epoch_no < 0:
+ return None
+ return remove_epoch_no
+
+
+def get_remove_step_no(args: argparse.Namespace, step_no: int):
+ if args.save_last_n_steps is None:
+ return None
+
+ # calculate the step number to remove from the last_n_steps and save_every_n_steps
+ # e.g. if save_every_n_steps=10, save_last_n_steps=30, at step 50, keep 30 steps and remove step 10
+ remove_step_no = step_no - args.save_last_n_steps - 1
+ remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+ if remove_step_no < 0:
+ return None
+ return remove_step_no
+
+
+def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info(f"saving state at epoch {epoch_no}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+ accelerator.save_state(state_dir)
+ if args.save_state_to_huggingface:
+ logger.info("uploading state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
+
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+ if last_n_epochs is not None:
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
+ if os.path.exists(state_dir_old):
+ logger.info(f"removing old state: {state_dir_old}")
+ shutil.rmtree(state_dir_old)
+
+
+def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info(f"saving state at step {step_no}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
+ accelerator.save_state(state_dir)
+ if args.save_state_to_huggingface:
+ logger.info("uploading state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
+
+ last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
+ if last_n_steps is not None:
+ # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
+ remove_step_no = step_no - last_n_steps - 1
+ remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+
+ if remove_step_no > 0:
+ state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
+ if os.path.exists(state_dir_old):
+ logger.info(f"removing old state: {state_dir_old}")
+ shutil.rmtree(state_dir_old)
+
+
+def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info("saving last state.")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
+ accelerator.save_state(state_dir)
+
+ if args.save_state_to_huggingface:
+ logger.info("uploading last state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
+
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000001.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000001.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..4aaaf2562add18470a6f7ef90f49da7c1221822a
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000001.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c83468be7bc357b777fa900e3ed4fd4452a142ba46dde32074a7d7e15ba9695c
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000002.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000002.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a7db7848ab3e209cfb7b9226d6c7946d6c58a1c3
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000002.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a56b9f52c800b05f93a1437475942deae891d04dab83ae2ad34e179606fbdce
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000003.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000003.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..d0c4b5530e896ff898ad3e1c95d734a77ab9ab46
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000003.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24b72675e75de1ef2e06d9a7abfb4bfdd17c70ce26901beccd5085dfffe664d8
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000004.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3c7f5d28ed1218d6f2822998f26ee45965402e7e
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f0aa3e1a2610f9d1e80cffe36a022a96d03c6890a47de2cd48e7fef5917546e
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000005.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000005.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..cb38caf0c86212c04907e44a8935ce8773c0f00e
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000005.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:536d31441ec18c50c3db6ab089f53bf14b62ce9ad9256e1fbd64b9e0dd58a8e6
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000006.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000006.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1749a0390924aeff70baeb03bdcd4f8d5e7d55a6
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000006.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6f51d038ea3d3bbe7e30b780c7e69ebf932029c57daf45f5f2b7802b7617464
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000007.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000007.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c7fde500e32e6a7e8d3c97873366552a5bcdc52f
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000007.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9fcb428b85762e86071a631e4e2d6b438ac8d157e616db50c05a7ec2efbc67dc
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000008.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000008.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..479fc8ff03f0e7ddab7265c72c6a91cd1d402238
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000008.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57a00a9e3a7f5ebd7de3fe1ae947f73e55c5a0c06f5adb25226d8e606d20992b
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000009.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000009.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..bc76c2618d345df003aac99b0dc060f1a6628aff
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000009.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db9255ec51a3015974ea04de5d309682e38e2130683d9d4f2283aa6cda120021
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000010.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000010.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..ad0c128107c88ac99b2d41e793ab9b4584ef5b2e
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000010.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07ff9643355fddf8f583dc94d73e412223630d24a338c59bfc7d3dc9c8107eca
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000011.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000011.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..410640c1bd2b494c8e6991af373e4967c532b636
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000011.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:272ef2da97d9e142084b5dbef41d419a65c8080c316a46ff9870caf793236fdd
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000012.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000012.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..74c17b7abf24b74925b1c043d6f3d6bb02ffa943
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000012.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0308166fce29bed867ad7c31ac796d5b21264c8c0ba0638a468dea6cec2dfcd
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000013.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000013.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..14f0d884bebacd22edbd2b3ce5168f052c257d33
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000013.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:395d9e12e8ca4731c9ba8fcb95989cbeddfd14acfe59d0c5eed517259ff6beed
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000014.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000014.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..250b8c9c9ed23251b0fce0ca180dac5f3bbd21b1
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000014.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cd4b94b80d82bfc10627feb2f30e64f734a0cb5421bdcb7c32945b87427aecb
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000015.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000015.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1269812f1cd6cedc15581ecad5b5486f9ac1bb97
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000015.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d51d043100ad64b2ae16ff2eb8479e599bd2a31d8e43994daa25ad9687b26ef5
+size 322557560
diff --git a/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora.safetensors b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..6bbd06e08cd69694445ef22bae3cf21ab5bbb4b1
--- /dev/null
+++ b/zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18d9619a973101e93df0be764fb09e1ad358b56862ef3afdba02a64a2182f661
+size 322557560