from concurrent.futures import ThreadPoolExecutor
import glob
import json
import math
import os
import random
import time
from typing import Callable, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from safetensors.torch import save_file, load_file
from safetensors import safe_open
from PIL import Image
import cv2
import av
from utils import safetensors_utils
from utils.model_utils import dtype_to_str
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
try:
import pillow_avif
IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
pass
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
pass
# JPEG-XL on Windows
try:
import pillow_jxl
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
pass
VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"] # some of them are not tested
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == "*":
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths)) # remove duplicates
img_paths.sort()
return img_paths
def glob_videos(directory, base="*"):
video_paths = []
for ext in VIDEO_EXTENSIONS:
if base == "*":
video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
video_paths = list(set(video_paths)) # remove duplicates
video_paths.sort()
return video_paths
def divisible_by(num: int, divisor: int) -> int:
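    # Round num down to the nearest multiple of divisor, e.g. divisible_by(1070, 16) == 1056.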
return num - num % divisor
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
"""
Resize the image to the bucket resolution.
"""
is_pil_image = isinstance(image, Image.Image)
if is_pil_image:
image_width, image_height = image.size
else:
image_height, image_width = image.shape[:2]
if bucket_reso == (image_width, image_height):
return np.array(image) if is_pil_image else image
bucket_width, bucket_height = bucket_reso
if bucket_width == image_width or bucket_height == image_height:
image = np.array(image) if is_pil_image else image
else:
# resize the image to the bucket resolution to match the short side
scale_width = bucket_width / image_width
scale_height = bucket_height / image_height
scale = max(scale_width, scale_height)
image_width = int(image_width * scale + 0.5)
image_height = int(image_height * scale + 0.5)
if scale > 1:
image = Image.fromarray(image) if not is_pil_image else image
image = image.resize((image_width, image_height), Image.LANCZOS)
image = np.array(image)
else:
image = np.array(image) if is_pil_image else image
image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
# crop the image to the bucket resolution
crop_left = (image_width - bucket_width) // 2
crop_top = (image_height - bucket_height) // 2
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
return image
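# Illustrative example (not executed): the image is scaled so that it covers the bucket resolution,
# then center-cropped to the exact bucket size.
#   >>> img = np.zeros((720, 1280, 3), dtype=np.uint8)  # (height, width, channel)
#   >>> resize_image_to_bucket(img, (960, 544)).shape   # bucket is (width, height)
#   (544, 960, 3)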
class ItemInfo:
def __init__(
self,
item_key: str,
caption: str,
original_size: tuple[int, int],
bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
frame_count: Optional[int] = None,
content: Optional[np.ndarray] = None,
latent_cache_path: Optional[str] = None,
) -> None:
self.item_key = item_key
self.caption = caption
self.original_size = original_size
self.bucket_size = bucket_size
self.frame_count = frame_count
self.content = content
self.latent_cache_path = latent_cache_path
self.text_encoder_output_cache_path: Optional[str] = None
def __str__(self) -> str:
return (
f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
+ f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
+ f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
)
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
metadata = {
"architecture": "hunyuan_video",
"width": f"{item_info.original_size[0]}",
"height": f"{item_info.original_size[1]}",
"format_version": "1.0.0",
}
if item_info.frame_count is not None:
metadata["frame_count"] = f"{item_info.frame_count}"
_, F, H, W = latent.shape
dtype_str = dtype_to_str(latent.dtype)
sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
latent_dir = os.path.dirname(item_info.latent_cache_path)
os.makedirs(latent_dir, exist_ok=True)
save_file(sd, item_info.latent_cache_path, metadata=metadata)
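# Illustrative sketch (not executed): each latent cache file holds one tensor whose key encodes the
# latent shape and dtype, plus string metadata. Assuming a cache file already exists, it can be read
# back lazily with safe_open:
#   >>> with safe_open(item_info.latent_cache_path, framework="pt") as f:
#   ...     key = next(k for k in f.keys() if k.startswith("latents_"))  # "latents_{F}x{H}x{W}_{dtype}"
#   ...     latent = f.get_tensor(key)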
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    assert embed.dim() in (1, 2), f"embed should be a 1D (hidden_size,) or 2D (feature, hidden_size) tensor, got {embed.shape}"
assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
metadata = {
"architecture": "hunyuan_video",
"caption1": item_info.caption,
"format_version": "1.0.0",
}
sd = {}
if os.path.exists(item_info.text_encoder_output_cache_path):
# load existing cache and update metadata
with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
existing_metadata = f.metadata()
for key in f.keys():
sd[key] = f.get_tensor(key)
assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
if existing_metadata["caption1"] != metadata["caption1"]:
logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
# TODO verify format_version
existing_metadata.pop("caption1", None)
existing_metadata.pop("format_version", None)
metadata.update(existing_metadata) # copy existing metadata
else:
text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
os.makedirs(text_encoder_output_dir, exist_ok=True)
dtype_str = dtype_to_str(embed.dtype)
text_encoder_type = "llm" if is_llm else "clipL"
sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
if mask is not None:
sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
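# Illustrative note: the text encoder cache stores "llm_{dtype}" / "llm_mask" for the LLM encoder and
# "clipL_{dtype}" for CLIP-L, so both encoders can share one file per item (an existing file is merged
# in before saving).
#   >>> sd = load_file(item_info.text_encoder_output_cache_path)  # not executed
#   >>> sorted(sd.keys())  # e.g. ["clipL_<dtype>", "llm_<dtype>", "llm_mask"]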
class BucketSelector:
RESOLUTION_STEPS_HUNYUAN = 16
def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
self.resolution = resolution
self.bucket_area = resolution[0] * resolution[1]
self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
if not enable_bucket:
# only define one bucket
self.bucket_resolutions = [resolution]
self.no_upscale = False
else:
# prepare bucket resolution
self.no_upscale = no_upscale
sqrt_size = int(math.sqrt(self.bucket_area))
min_size = divisible_by(sqrt_size // 2, self.reso_steps)
self.bucket_resolutions = []
for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
h = divisible_by(self.bucket_area // w, self.reso_steps)
self.bucket_resolutions.append((w, h))
self.bucket_resolutions.append((h, w))
self.bucket_resolutions = list(set(self.bucket_resolutions))
self.bucket_resolutions.sort()
# calculate aspect ratio to find the nearest resolution
self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
"""
return the bucket resolution for the given image size, (width, height)
"""
area = image_size[0] * image_size[1]
if self.no_upscale and area <= self.bucket_area:
w, h = image_size
w = divisible_by(w, self.reso_steps)
h = divisible_by(h, self.reso_steps)
return w, h
aspect_ratio = image_size[0] / image_size[1]
ar_errors = self.aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
return self.bucket_resolutions[bucket_id]
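# Illustrative example (not executed): with the default 960x544 area and 16-pixel steps, a 1280x720
# source maps to the bucket whose aspect ratio is closest to 16:9.
#   >>> selector = BucketSelector((960, 544))
#   >>> selector.get_bucket_resolution((1280, 720))  # (width, height)
#   (960, 544)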
def load_video(
video_path: str,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> list[np.ndarray]:
container = av.open(video_path)
video = []
bucket_reso = None
for i, frame in enumerate(container.decode(video=0)):
if start_frame is not None and i < start_frame:
continue
if end_frame is not None and i >= end_frame:
break
frame = frame.to_image()
if bucket_selector is not None and bucket_reso is None:
bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
if bucket_reso is not None:
frame = resize_image_to_bucket(frame, bucket_reso)
else:
frame = np.array(frame)
video.append(frame)
container.close()
return video
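# Illustrative example (assumed path, not executed): frames are decoded with PyAV and, if a
# BucketSelector is given, resized to the bucket chosen from the first frame's size.
#   >>> frames = load_video("/path/to/clip.mp4", start_frame=0, end_frame=49,
#   ...                     bucket_selector=BucketSelector((960, 544)))
#   >>> len(frames), frames[0].shape  # e.g. (49, (544, 960, 3)) if the clip has at least 49 frames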
class BucketBatchManager:
def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
self.batch_size = batch_size
self.buckets = bucketed_item_info
self.bucket_resos = list(self.buckets.keys())
self.bucket_resos.sort()
self.bucket_batch_indices = []
for bucket_reso in self.bucket_resos:
bucket = self.buckets[bucket_reso]
num_batches = math.ceil(len(bucket) / self.batch_size)
for i in range(num_batches):
self.bucket_batch_indices.append((bucket_reso, i))
self.shuffle()
def show_bucket_info(self):
for bucket_reso in self.bucket_resos:
bucket = self.buckets[bucket_reso]
logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
logger.info(f"total batches: {len(self)}")
def shuffle(self):
for bucket in self.buckets.values():
random.shuffle(bucket)
random.shuffle(self.bucket_batch_indices)
def __len__(self):
return len(self.bucket_batch_indices)
def __getitem__(self, idx):
bucket_reso, batch_idx = self.bucket_batch_indices[idx]
bucket = self.buckets[bucket_reso]
start = batch_idx * self.batch_size
end = min(start + self.batch_size, len(bucket))
latents = []
llm_embeds = []
llm_masks = []
clip_l_embeds = []
for item_info in bucket[start:end]:
sd = load_file(item_info.latent_cache_path)
latent = None
for key in sd.keys():
if key.startswith("latents_"):
latent = sd[key]
break
latents.append(latent)
sd = load_file(item_info.text_encoder_output_cache_path)
llm_embed = llm_mask = clip_l_embed = None
for key in sd.keys():
if key.startswith("llm_mask"):
llm_mask = sd[key]
elif key.startswith("llm_"):
llm_embed = sd[key]
elif key.startswith("clipL_mask"):
pass
elif key.startswith("clipL_"):
clip_l_embed = sd[key]
llm_embeds.append(llm_embed)
llm_masks.append(llm_mask)
clip_l_embeds.append(clip_l_embed)
latents = torch.stack(latents)
llm_embeds = torch.stack(llm_embeds)
llm_masks = torch.stack(llm_masks)
clip_l_embeds = torch.stack(clip_l_embeds)
return latents, llm_embeds, llm_masks, clip_l_embeds
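# Note: BucketBatchManager indexes batches as (bucket_resolution, batch_index) pairs, so every batch
# returned by __getitem__ contains items of a single resolution (and frame count for videos). The
# returned tuple is (latents, llm_embeds, llm_masks, clip_l_embeds), each stacked along dim 0.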
class ContentDatasource:
def __init__(self):
self.caption_only = False
def set_caption_only(self, caption_only: bool):
self.caption_only = caption_only
def is_indexable(self):
return False
def get_caption(self, idx: int) -> tuple[str, str]:
"""
        Returns the (item_key, caption) tuple for the given index. Called only when is_indexable() returns True.
"""
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __iter__(self):
raise NotImplementedError
def __next__(self):
raise NotImplementedError
class ImageDatasource(ContentDatasource):
def __init__(self):
super().__init__()
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
"""
Returns image data as a tuple of image path, image, and caption for the given index.
Key must be unique and valid as a file name.
        Called only when is_indexable() returns True.
"""
raise NotImplementedError
class ImageDirectoryDatasource(ImageDatasource):
def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
super().__init__()
self.image_directory = image_directory
self.caption_extension = caption_extension
self.current_idx = 0
# glob images
logger.info(f"glob images in {self.image_directory}")
self.image_paths = glob_images(self.image_directory)
logger.info(f"found {len(self.image_paths)} images")
def is_indexable(self):
return True
def __len__(self):
return len(self.image_paths)
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
image_path = self.image_paths[idx]
image = Image.open(image_path).convert("RGB")
_, caption = self.get_caption(idx)
return image_path, image, caption
def get_caption(self, idx: int) -> tuple[str, str]:
image_path = self.image_paths[idx]
        assert self.caption_extension, "caption_extension is required to read captions"
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
return image_path, caption
def __iter__(self):
self.current_idx = 0
return self
    def __next__(self) -> Callable:
"""
Returns a fetcher function that returns image data.
"""
if self.current_idx >= len(self.image_paths):
raise StopIteration
if self.caption_only:
def create_caption_fetcher(index):
return lambda: self.get_caption(index)
fetcher = create_caption_fetcher(self.current_idx)
else:
def create_image_fetcher(index):
return lambda: self.get_image_data(index)
fetcher = create_image_fetcher(self.current_idx)
self.current_idx += 1
return fetcher
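# Illustrative layout (assumed, not enforced here): with caption_extension=".txt", captions are read
# from a sidecar file next to each image.
#   /data/images/0001.png
#   /data/images/0001.txt  <- caption for 0001.png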
class ImageJsonlDatasource(ImageDatasource):
def __init__(self, image_jsonl_file: str):
super().__init__()
self.image_jsonl_file = image_jsonl_file
self.current_idx = 0
# load jsonl
logger.info(f"load image jsonl from {self.image_jsonl_file}")
self.data = []
with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
self.data.append(data)
logger.info(f"loaded {len(self.data)} images")
def is_indexable(self):
return True
def __len__(self):
return len(self.data)
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
data = self.data[idx]
image_path = data["image_path"]
image = Image.open(image_path).convert("RGB")
caption = data["caption"]
return image_path, image, caption
def get_caption(self, idx: int) -> tuple[str, str]:
data = self.data[idx]
image_path = data["image_path"]
caption = data["caption"]
return image_path, caption
def __iter__(self):
self.current_idx = 0
return self
    def __next__(self) -> Callable:
if self.current_idx >= len(self.data):
raise StopIteration
if self.caption_only:
def create_caption_fetcher(index):
return lambda: self.get_caption(index)
fetcher = create_caption_fetcher(self.current_idx)
else:
def create_fetcher(index):
return lambda: self.get_image_data(index)
fetcher = create_fetcher(self.current_idx)
self.current_idx += 1
return fetcher
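# Illustrative JSONL format for ImageJsonlDatasource (one JSON object per line; paths are examples):
#   {"image_path": "/data/images/0001.png", "caption": "a cat sitting on a table"}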
class VideoDatasource(ContentDatasource):
def __init__(self):
super().__init__()
# None means all frames
self.start_frame = None
self.end_frame = None
self.bucket_selector = None
def __len__(self):
raise NotImplementedError
def get_video_data_from_path(
self,
video_path: str,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        # If bucket_selector is given, each frame is resized to its bucket on load to reduce memory usage.
start_frame = start_frame if start_frame is not None else self.start_frame
end_frame = end_frame if end_frame is not None else self.end_frame
bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
video = load_video(video_path, start_frame, end_frame, bucket_selector)
return video
def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
self.start_frame = start_frame
self.end_frame = end_frame
def set_bucket_selector(self, bucket_selector: BucketSelector):
self.bucket_selector = bucket_selector
def __iter__(self):
raise NotImplementedError
def __next__(self):
raise NotImplementedError
class VideoDirectoryDatasource(VideoDatasource):
def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
super().__init__()
self.video_directory = video_directory
self.caption_extension = caption_extension
self.current_idx = 0
        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
self.video_paths = glob_videos(self.video_directory)
logger.info(f"found {len(self.video_paths)} videos")
def is_indexable(self):
return True
def __len__(self):
return len(self.video_paths)
def get_video_data(
self,
idx: int,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> tuple[str, list[Image.Image], str]:
video_path = self.video_paths[idx]
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
_, caption = self.get_caption(idx)
return video_path, video, caption
def get_caption(self, idx: int) -> tuple[str, str]:
video_path = self.video_paths[idx]
        assert self.caption_extension, "caption_extension is required to read captions"
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
return video_path, caption
def __iter__(self):
self.current_idx = 0
return self
def __next__(self):
if self.current_idx >= len(self.video_paths):
raise StopIteration
if self.caption_only:
def create_caption_fetcher(index):
return lambda: self.get_caption(index)
fetcher = create_caption_fetcher(self.current_idx)
else:
def create_fetcher(index):
return lambda: self.get_video_data(index)
fetcher = create_fetcher(self.current_idx)
self.current_idx += 1
return fetcher
class VideoJsonlDatasource(VideoDatasource):
def __init__(self, video_jsonl_file: str):
super().__init__()
self.video_jsonl_file = video_jsonl_file
self.current_idx = 0
# load jsonl
logger.info(f"load video jsonl from {self.video_jsonl_file}")
self.data = []
with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
self.data.append(data)
logger.info(f"loaded {len(self.data)} videos")
def is_indexable(self):
return True
def __len__(self):
return len(self.data)
def get_video_data(
self,
idx: int,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
bucket_selector: Optional[BucketSelector] = None,
) -> tuple[str, list[Image.Image], str]:
data = self.data[idx]
video_path = data["video_path"]
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
caption = data["caption"]
return video_path, video, caption
def get_caption(self, idx: int) -> tuple[str, str]:
data = self.data[idx]
video_path = data["video_path"]
caption = data["caption"]
return video_path, caption
def __iter__(self):
self.current_idx = 0
return self
def __next__(self):
if self.current_idx >= len(self.data):
raise StopIteration
if self.caption_only:
def create_caption_fetcher(index):
return lambda: self.get_caption(index)
fetcher = create_caption_fetcher(self.current_idx)
else:
def create_fetcher(index):
return lambda: self.get_video_data(index)
fetcher = create_fetcher(self.current_idx)
self.current_idx += 1
return fetcher
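# Illustrative JSONL format for VideoJsonlDatasource (one JSON object per line; paths are examples):
#   {"video_path": "/data/videos/0001.mp4", "caption": "a dog running on the beach"}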
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self,
resolution: Tuple[int, int] = (960, 544),
caption_extension: Optional[str] = None,
batch_size: int = 1,
enable_bucket: bool = False,
bucket_no_upscale: bool = False,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
):
self.resolution = resolution
self.caption_extension = caption_extension
self.batch_size = batch_size
self.enable_bucket = enable_bucket
self.bucket_no_upscale = bucket_no_upscale
self.cache_directory = cache_directory
self.debug_dataset = debug_dataset
self.seed = None
self.current_epoch = 0
if not self.enable_bucket:
self.bucket_no_upscale = False
def get_metadata(self) -> dict:
metadata = {
"resolution": self.resolution,
"caption_extension": self.caption_extension,
"batch_size_per_device": self.batch_size,
"enable_bucket": bool(self.enable_bucket),
"bucket_no_upscale": bool(self.bucket_no_upscale),
}
return metadata
def get_latent_cache_path(self, item_info: ItemInfo) -> str:
w, h = item_info.original_size
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required"
return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")
def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required"
return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")
def retrieve_latent_cache_batches(self, num_workers: int):
raise NotImplementedError
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
raise NotImplementedError
def prepare_for_training(self):
pass
def set_seed(self, seed: int):
self.seed = seed
def set_current_epoch(self, epoch):
        if self.current_epoch != epoch:  # shuffle buckets when epoch is incremented
if epoch > self.current_epoch:
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
num_epochs = epoch - self.current_epoch
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
                # self.current_epoch seems to be reset to 0 at the start of the next epoch; this may be caused by skipped_dataloader
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def shuffle_buckets(self):
raise NotImplementedError
def __len__(self):
        raise NotImplementedError
def __getitem__(self, idx):
raise NotImplementedError
def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
datasource.set_caption_only(True)
executor = ThreadPoolExecutor(max_workers=num_workers)
data: list[ItemInfo] = []
futures = []
def aggregate_future(consume_all: bool = False):
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
completed_futures = [future for future in futures if future.done()]
if len(completed_futures) == 0:
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
time.sleep(0.1)
continue
else:
break # submit batch if possible
for future in completed_futures:
item_key, caption = future.result()
item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
data.append(item_info)
futures.remove(future)
def submit_batch(flush: bool = False):
nonlocal data
if len(data) >= batch_size or (len(data) > 0 and flush):
batch = data[0:batch_size]
if len(data) > batch_size:
data = data[batch_size:]
else:
data = []
return batch
return None
for fetch_op in datasource:
future = executor.submit(fetch_op)
futures.append(future)
aggregate_future()
while True:
batch = submit_batch()
if batch is None:
break
yield batch
aggregate_future(consume_all=True)
while True:
batch = submit_batch(flush=True)
if batch is None:
break
yield batch
executor.shutdown()
class ImageDataset(BaseDataset):
def __init__(
self,
resolution: Tuple[int, int],
caption_extension: Optional[str],
batch_size: int,
enable_bucket: bool,
bucket_no_upscale: bool,
image_directory: Optional[str] = None,
image_jsonl_file: Optional[str] = None,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
):
super(ImageDataset, self).__init__(
resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
)
self.image_directory = image_directory
self.image_jsonl_file = image_jsonl_file
if image_directory is not None:
self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
elif image_jsonl_file is not None:
self.datasource = ImageJsonlDatasource(image_jsonl_file)
else:
raise ValueError("image_directory or image_jsonl_file must be specified")
if self.cache_directory is None:
self.cache_directory = self.image_directory
self.batch_manager = None
self.num_train_items = 0
def get_metadata(self):
metadata = super().get_metadata()
if self.image_directory is not None:
metadata["image_directory"] = os.path.basename(self.image_directory)
if self.image_jsonl_file is not None:
metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
return metadata
def get_total_image_count(self):
return len(self.datasource) if self.datasource.is_indexable() else None
def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
executor = ThreadPoolExecutor(max_workers=num_workers)
batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
futures = []
def aggregate_future(consume_all: bool = False):
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
completed_futures = [future for future in futures if future.done()]
if len(completed_futures) == 0:
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
time.sleep(0.1)
continue
else:
break # submit batch if possible
for future in completed_futures:
original_size, item_key, image, caption = future.result()
bucket_height, bucket_width = image.shape[:2]
bucket_reso = (bucket_width, bucket_height)
item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
if bucket_reso not in batches:
batches[bucket_reso] = []
batches[bucket_reso].append(item_info)
futures.remove(future)
def submit_batch(flush: bool = False):
for key in batches:
if len(batches[key]) >= self.batch_size or flush:
batch = batches[key][0 : self.batch_size]
if len(batches[key]) > self.batch_size:
batches[key] = batches[key][self.batch_size :]
else:
del batches[key]
return key, batch
return None, None
for fetch_op in self.datasource:
            def fetch_and_resize(op: Callable) -> tuple[tuple[int, int], str, np.ndarray, str]:
image_key, image, caption = op()
image: Image.Image
image_size = image.size
                bucket_reso = bucket_selector.get_bucket_resolution(image_size)
image = resize_image_to_bucket(image, bucket_reso)
return image_size, image_key, image, caption
future = executor.submit(fetch_and_resize, fetch_op)
futures.append(future)
aggregate_future()
while True:
key, batch = submit_batch()
if key is None:
break
yield key, batch
aggregate_future(consume_all=True)
while True:
key, batch = submit_batch(flush=True)
if key is None:
break
yield key, batch
executor.shutdown()
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
def prepare_for_training(self):
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
# glob cache files
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
# assign cache files to item info
bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
for cache_file in latent_cache_files:
tokens = os.path.basename(cache_file).split("_")
image_size = tokens[-2] # 0000x0000
image_width, image_height = map(int, image_size.split("x"))
image_size = (image_width, image_height)
item_key = "_".join(tokens[:-2])
text_encoder_output_cache_file = os.path.join(
self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
)
if not os.path.exists(text_encoder_output_cache_file):
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
continue
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
bucket = bucketed_item_info.get(bucket_reso, [])
bucket.append(item_info)
bucketed_item_info[bucket_reso] = bucket
# prepare batch manager
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
self.batch_manager.show_bucket_info()
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
self.batch_manager.shuffle()
def __len__(self):
if self.batch_manager is None:
return 100 # dummy value
return len(self.batch_manager)
def __getitem__(self, idx):
return self.batch_manager[idx]
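# Illustrative usage sketch (assumed paths, not executed): caching latents for an image folder.
#   >>> dataset = ImageDataset(resolution=(960, 544), caption_extension=".txt", batch_size=4,
#   ...                        enable_bucket=True, bucket_no_upscale=False,
#   ...                        image_directory="/data/images", cache_directory="/data/cache")
#   >>> for bucket_reso, items in dataset.retrieve_latent_cache_batches(num_workers=2):
#   ...     pass  # encode each item.content with the VAE, then call save_latent_cache(item, latent)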
class VideoDataset(BaseDataset):
def __init__(
self,
resolution: Tuple[int, int],
caption_extension: Optional[str],
batch_size: int,
enable_bucket: bool,
bucket_no_upscale: bool,
frame_extraction: Optional[str] = "head",
frame_stride: Optional[int] = 1,
frame_sample: Optional[int] = 1,
target_frames: Optional[list[int]] = None,
video_directory: Optional[str] = None,
video_jsonl_file: Optional[str] = None,
cache_directory: Optional[str] = None,
debug_dataset: bool = False,
):
super(VideoDataset, self).__init__(
resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
)
self.video_directory = video_directory
self.video_jsonl_file = video_jsonl_file
self.target_frames = target_frames
self.frame_extraction = frame_extraction
self.frame_stride = frame_stride
self.frame_sample = frame_sample
        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            raise ValueError("video_directory or video_jsonl_file must be specified")
if self.frame_extraction == "uniform" and self.frame_sample == 1:
self.frame_extraction = "head"
logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
if self.frame_extraction == "head":
# head extraction. we can limit the number of frames to be extracted
self.datasource.set_start_and_end_frame(0, max(self.target_frames))
if self.cache_directory is None:
self.cache_directory = self.video_directory
self.batch_manager = None
self.num_train_items = 0
def get_metadata(self):
metadata = super().get_metadata()
if self.video_directory is not None:
metadata["video_directory"] = os.path.basename(self.video_directory)
if self.video_jsonl_file is not None:
metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
metadata["frame_extraction"] = self.frame_extraction
metadata["frame_stride"] = self.frame_stride
metadata["frame_sample"] = self.frame_sample
metadata["target_frames"] = self.target_frames
return metadata
def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution)
        self.datasource.set_bucket_selector(bucket_selector)
executor = ThreadPoolExecutor(max_workers=num_workers)
# key: (width, height, frame_count), value: [ItemInfo]
batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
futures = []
def aggregate_future(consume_all: bool = False):
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
completed_futures = [future for future in futures if future.done()]
if len(completed_futures) == 0:
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
time.sleep(0.1)
continue
else:
break # submit batch if possible
for future in completed_futures:
original_frame_size, video_key, video, caption = future.result()
frame_count = len(video)
video = np.stack(video, axis=0)
height, width = video.shape[1:3]
bucket_reso = (width, height) # already resized
crop_pos_and_frames = []
if self.frame_extraction == "head":
for target_frame in self.target_frames:
if frame_count >= target_frame:
crop_pos_and_frames.append((0, target_frame))
elif self.frame_extraction == "chunk":
# split by target_frames
for target_frame in self.target_frames:
for i in range(0, frame_count, target_frame):
if i + target_frame <= frame_count:
crop_pos_and_frames.append((i, target_frame))
elif self.frame_extraction == "slide":
# slide window
for target_frame in self.target_frames:
if frame_count >= target_frame:
for i in range(0, frame_count - target_frame + 1, self.frame_stride):
crop_pos_and_frames.append((i, target_frame))
elif self.frame_extraction == "uniform":
# select N frames uniformly
for target_frame in self.target_frames:
if frame_count >= target_frame:
frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
for i in frame_indices:
crop_pos_and_frames.append((i, target_frame))
else:
raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
for crop_pos, target_frame in crop_pos_and_frames:
cropped_video = video[crop_pos : crop_pos + target_frame]
body, ext = os.path.splitext(video_key)
item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
item_info = ItemInfo(
item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
)
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
batch = batches.get(batch_key, [])
batch.append(item_info)
batches[batch_key] = batch
futures.remove(future)
def submit_batch(flush: bool = False):
for key in batches:
if len(batches[key]) >= self.batch_size or flush:
batch = batches[key][0 : self.batch_size]
if len(batches[key]) > self.batch_size:
batches[key] = batches[key][self.batch_size :]
else:
del batches[key]
return key, batch
return None, None
for operator in self.datasource:
            def fetch_and_resize(op: Callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
video_key, video, caption = op()
video: list[np.ndarray]
frame_size = (video[0].shape[1], video[0].shape[0])
# resize if necessary
                bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
return frame_size, video_key, video, caption
future = executor.submit(fetch_and_resize, operator)
futures.append(future)
aggregate_future()
while True:
key, batch = submit_batch()
if key is None:
break
yield key, batch
aggregate_future(consume_all=True)
while True:
key, batch = submit_batch(flush=True)
if key is None:
break
yield key, batch
executor.shutdown()
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
def prepare_for_training(self):
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
# glob cache files
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
# assign cache files to item info
bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
for cache_file in latent_cache_files:
tokens = os.path.basename(cache_file).split("_")
image_size = tokens[-2] # 0000x0000
image_width, image_height = map(int, image_size.split("x"))
image_size = (image_width, image_height)
frame_pos, frame_count = tokens[-3].split("-")
frame_pos, frame_count = int(frame_pos), int(frame_count)
item_key = "_".join(tokens[:-3])
text_encoder_output_cache_file = os.path.join(
self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
)
if not os.path.exists(text_encoder_output_cache_file):
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
continue
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
bucket_reso = (*bucket_reso, frame_count)
item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
bucket = bucketed_item_info.get(bucket_reso, [])
bucket.append(item_info)
bucketed_item_info[bucket_reso] = bucket
# prepare batch manager
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
self.batch_manager.show_bucket_info()
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
self.batch_manager.shuffle()
def __len__(self):
if self.batch_manager is None:
return 100 # dummy value
return len(self.batch_manager)
def __getitem__(self, idx):
return self.batch_manager[idx]
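# Note on frame_extraction (see retrieve_latent_cache_batches above):
#   "head"    - take the first N frames for each N in target_frames
#   "chunk"   - split the video into consecutive, non-overlapping N-frame chunks
#   "slide"   - sliding window of N frames advanced by frame_stride
#   "uniform" - frame_sample windows of N frames with uniformly spaced start positions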
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
super().__init__(datasets)
self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
self.num_train_items = 0
for dataset in self.datasets:
self.num_train_items += dataset.num_train_items
def set_current_epoch(self, epoch):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
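# Illustrative usage sketch (assumed datasets, not executed): combining image and video datasets.
# Each dataset item is already a collated batch, so automatic batching is disabled in the DataLoader.
#   >>> for ds in (image_dataset, video_dataset):
#   ...     ds.prepare_for_training()
#   >>> group = DatasetGroup([image_dataset, video_dataset])
#   >>> loader = torch.utils.data.DataLoader(group, batch_size=None, num_workers=1)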