from concurrent.futures import ThreadPoolExecutor
import glob
import json
import logging
import math
import os
import random
import time
from typing import Callable, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from safetensors.torch import save_file, load_file
from safetensors import safe_open
from PIL import Image
import cv2
import av

from utils import safetensors_utils
from utils.model_utils import dtype_to_str

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass


# JPEG XL support can come from either jxlpy or pillow_jxl; duplicates in
# IMAGE_EXTENSIONS are harmless because glob results are deduplicated.
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass


try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass


VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"]

ARCHITECTURE_HUNYUAN_VIDEO = "hv"


def glob_images(directory, base="*"):
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    img_paths = list(set(img_paths))  # remove duplicates
    img_paths.sort()
    return img_paths


def glob_videos(directory, base="*"):
    video_paths = []
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    video_paths = list(set(video_paths))  # remove duplicates
    video_paths.sort()
    return video_paths


def divisible_by(num: int, divisor: int) -> int:
    """Round `num` down to the nearest multiple of `divisor`."""
    return num - num % divisor
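

# Example (illustrative): divisible_by(1000, 16) == 992, the largest multiple
# of 16 that does not exceed 1000.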


def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution (width, height): scale so the
    image fully covers the bucket, then center-crop the excess. Returns an np.ndarray.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        image_height, image_width = image.shape[:2]

    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image

    bucket_width, bucket_height = bucket_reso
    if bucket_width == image_width or bucket_height == image_height:
        image = np.array(image) if is_pil_image else image
    else:
        # scale so the image covers the bucket in both dimensions
        scale_width = bucket_width / image_width
        scale_height = bucket_height / image_height
        scale = max(scale_width, scale_height)
        image_width = int(image_width * scale + 0.5)
        image_height = int(image_height * scale + 0.5)

        if scale > 1:
            # PIL's LANCZOS gives better quality when upscaling
            image = Image.fromarray(image) if not is_pil_image else image
            image = image.resize((image_width, image_height), Image.LANCZOS)
            image = np.array(image)
        else:
            # cv2's INTER_AREA is preferable when downscaling
            image = np.array(image) if is_pil_image else image
            image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)

    # center crop to the exact bucket size
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
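

def _example_resize_to_bucket():
    # Illustrative sketch (not used by the module): fit a dummy 640x480 image
    # into a 624x464 bucket. The sizes here are arbitrary assumptions.
    img = Image.new("RGB", (640, 480), (128, 128, 128))
    resized = resize_image_to_bucket(img, (624, 464))
    print(resized.shape)  # (464, 624, 3): scaled to cover, then center-cropped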


class ItemInfo:
    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        self.item_key = item_key
        self.caption = caption
        self.original_size = original_size
        self.bucket_size = bucket_size
        self.frame_count = frame_count
        self.content = content
        self.latent_cache_path = latent_cache_path
        self.text_encoder_output_cache_path: Optional[str] = None

    def __str__(self) -> str:
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
        )


def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
    metadata = {
        "architecture": "hunyuan_video",
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.0",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    _, F, H, W = latent.shape  # latent is (C, F, H, W)
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    latent_dir = os.path.dirname(item_info.latent_cache_path)
    os.makedirs(latent_dir, exist_ok=True)

    save_file(sd, item_info.latent_cache_path, metadata=metadata)
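

def _example_read_latent_cache(path: str):
    # Illustrative sketch: inspect a cache written by save_latent_cache.
    # `path` is assumed to point at an existing *_hv.safetensors file.
    with safe_open(path, framework="pt") as f:
        print(f.metadata())  # architecture, width, height, format_version, ...
        for key in f.keys():  # e.g. "latents_<F>x<H>x<W>_<dtype>"
            print(key, f.get_tensor(key).shape)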


def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    assert (
        embed.dim() == 1 or embed.dim() == 2
    ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
    metadata = {
        "architecture": "hunyuan_video",
        "caption1": item_info.caption,
        "format_version": "1.0.0",
    }

    sd = {}
    if os.path.exists(item_info.text_encoder_output_cache_path):
        # merge the new output into the existing cache so that LLM and CLIP
        # outputs can live in a single file
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                sd[key] = f.get_tensor(key)

        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")

        # keep the existing metadata, but let the new caption and format version win
        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)
    else:
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    dtype_str = dtype_to_str(embed.dtype)
    text_encoder_type = "llm" if is_llm else "clipL"
    sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
    if mask is not None:
        sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
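

def _example_read_te_cache(path: str):
    # Illustrative sketch: keys in the text encoder cache follow
    # "<encoder>_<dtype>" and "<encoder>_mask", e.g. "llm_..." and "clipL_...".
    # `path` is assumed to be an existing *_te.safetensors file.
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            print(key, f.get_tensor(key).shape)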


class BucketSelector:
    RESOLUTION_STEPS_HUNYUAN = 16

    def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN

        if not enable_bucket:
            # only one bucket: the specified resolution
            self.bucket_resolutions = [resolution]
            self.no_upscale = False
        else:
            # generate buckets with (roughly) the same area and varying aspect ratios
            self.no_upscale = no_upscale
            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, self.reso_steps)
            self.bucket_resolutions = []
            for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
                h = divisible_by(self.bucket_area // w, self.reso_steps)
                self.bucket_resolutions.append((w, h))
                self.bucket_resolutions.append((h, w))

            self.bucket_resolutions = list(set(self.bucket_resolutions))
            self.bucket_resolutions.sort()

        # aspect ratios (w / h) for the nearest-bucket lookup
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        area = image_size[0] * image_size[1]
        if self.no_upscale and area <= self.bucket_area:
            # the image is smaller than the bucket area: keep its size, rounded
            # down to multiples of the resolution step
            w, h = image_size
            w = divisible_by(w, self.reso_steps)
            h = divisible_by(h, self.reso_steps)
            return w, h

        aspect_ratio = image_size[0] / image_size[1]
        ar_errors = self.aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        return self.bucket_resolutions[bucket_id]
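

def _example_bucket_selection():
    # Illustrative sketch (not used by the module). The (960, 544) resolution
    # is an arbitrary assumption for demonstration.
    selector = BucketSelector((960, 544), enable_bucket=True, no_upscale=False)
    # A 1280x720 image is matched to the bucket whose aspect ratio is closest
    # to 16:9; both sides are multiples of RESOLUTION_STEPS_HUNYUAN (16).
    print(selector.get_bucket_resolution((1280, 720)))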


def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
) -> list[np.ndarray]:
    container = av.open(video_path)
    video = []
    bucket_reso = None
    for i, frame in enumerate(container.decode(video=0)):
        if start_frame is not None and i < start_frame:
            continue
        if end_frame is not None and i >= end_frame:
            break
        frame = frame.to_image()

        # select the bucket from the first decoded frame
        if bucket_selector is not None and bucket_reso is None:
            bucket_reso = bucket_selector.get_bucket_resolution(frame.size)

        if bucket_reso is not None:
            frame = resize_image_to_bucket(frame, bucket_reso)
        else:
            frame = np.array(frame)

        video.append(frame)
    container.close()
    return video
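

def _example_load_video_head(video_path: str):
    # Illustrative sketch: decode the first 16 frames of a clip and fit each
    # frame to its bucket. `video_path` is a hypothetical input file.
    selector = BucketSelector((960, 544))
    frames = load_video(video_path, start_frame=0, end_frame=16, bucket_selector=selector)
    print(len(frames), frames[0].shape if frames else None)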


class BucketBatchManager:

    def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        # precompute (bucket resolution, batch index) pairs for __getitem__
        self.bucket_batch_indices = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        self.shuffle()

    def show_bucket_info(self):
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        # shuffle the items within each bucket and the order of batches
        for bucket in self.buckets.values():
            random.shuffle(bucket)
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        latents = []
        llm_embeds = []
        llm_masks = []
        clip_l_embeds = []
        for item_info in bucket[start:end]:
            sd = load_file(item_info.latent_cache_path)
            latent = None
            for key in sd.keys():
                if key.startswith("latents_"):
                    latent = sd[key]
                    break
            latents.append(latent)

            sd = load_file(item_info.text_encoder_output_cache_path)
            llm_embed = llm_mask = clip_l_embed = None
            for key in sd.keys():
                if key.startswith("llm_mask"):
                    llm_mask = sd[key]
                elif key.startswith("llm_"):
                    llm_embed = sd[key]
                elif key.startswith("clipL_mask"):
                    pass  # the clipL mask is not used
                elif key.startswith("clipL_"):
                    clip_l_embed = sd[key]
            llm_embeds.append(llm_embed)
            llm_masks.append(llm_mask)
            clip_l_embeds.append(clip_l_embed)

        latents = torch.stack(latents)
        llm_embeds = torch.stack(llm_embeds)
        llm_masks = torch.stack(llm_masks)
        clip_l_embeds = torch.stack(clip_l_embeds)

        return latents, llm_embeds, llm_masks, clip_l_embeds
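

def _example_iterate_batches(manager: BucketBatchManager):
    # Illustrative sketch: iterate one epoch of bucket batches. Each entry is
    # (latents, llm_embeds, llm_masks, clip_l_embeds), stacked along dim 0.
    for i in range(len(manager)):
        latents, llm_embeds, llm_masks, clip_l_embeds = manager[i]
        print(i, latents.shape)
    manager.shuffle()  # reshuffle between epochs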


class ContentDatasource:
    def __init__(self):
        self.caption_only = False

    def set_caption_only(self, caption_only: bool):
        self.caption_only = caption_only

    def is_indexable(self):
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError


class ImageDatasource(ContentDatasource):
    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError


class ImageDirectoryDatasource(ImageDatasource):
    def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob images
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        image_path = self.image_paths[idx]
        assert self.caption_extension is not None, "caption_extension is required to read captions from files"
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> Callable:
        """
        Returns a fetcher function that returns image data.
        """
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:
            # bind the current index into a zero-argument fetcher
            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
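

def _example_drive_datasource(ds: ContentDatasource):
    # Illustrative sketch: datasources yield zero-argument fetchers so that
    # file decoding can be deferred to worker threads (as the cache batch
    # retrieval below does with a ThreadPoolExecutor).
    ds.set_caption_only(True)
    for fetch in ds:
        item_key, caption = fetch()  # executed lazily
        print(item_key, caption)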


class ImageJsonlDatasource(ImageDatasource):
    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> Callable:
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
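

def _example_write_image_jsonl(path: str):
    # Illustrative sketch: write a minimal JSONL file in the format
    # ImageJsonlDatasource expects, one object per line with "image_path"
    # and "caption" keys. The row below is a made-up example.
    rows = [{"image_path": "/data/images/0001.png", "caption": "a sample caption"}]
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")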


class VideoDatasource(ContentDatasource):
    def __init__(self):
        super().__init__()

        # None means no frame range restriction
        self.start_frame = None
        self.end_frame = None

        self.bucket_selector = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        # prefer the per-call arguments over the datasource-wide settings
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        video = load_video(video_path, start_frame, end_frame, bucket_selector)
        return video

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        self.bucket_selector = bucket_selector

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError


class VideoDirectoryDatasource(VideoDatasource):
    def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        video_path = self.video_paths[idx]
        assert self.caption_extension is not None, "caption_extension is required to read captions from files"
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> Callable:
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher


class VideoJsonlDatasource(VideoDatasource):
    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> Callable:
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.seed = None
        self.current_epoch = 0

        if not self.enable_bucket:
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        if self.current_epoch != epoch:
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        # collect finished futures into item infos; wait only while the
        # executor is saturated (or when draining at the end)
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    item_key, caption = future.result()
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()


class ImageDataset(BaseDataset):
    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(ImageDataset, self).__init__(
            resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        return metadata

    def get_total_image_count(self):
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        batches: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_size, item_key, image, caption = future.result()
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:

            def fetch_and_resize(op: Callable) -> tuple[tuple[int, int], str, np.ndarray, str]:
                image_key, image, caption = op()
                image: Image.Image
                image_size = image.size

                bucket_reso = bucket_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)  # returns np.ndarray
                return image_size, image_key, image, caption

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob the latent cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cached items to buckets
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # for example, "0852x0480"
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare the batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set the random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy length before prepare_for_training is called
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]


class VideoDataset(BaseDataset):
    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(VideoDataset, self).__init__(
            resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.target_frames = target_frames
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            raise ValueError("video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            # decoding frames beyond the longest target is unnecessary for "head"
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        if self.cache_directory is None:
            self.cache_directory = self.video_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution)
        self.datasource.set_bucket_selector(bucket_selector)

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: (width, height, frame_count), value: [ItemInfo]
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_frame_size, video_key, video, caption = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # frames are already resized

                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # split the video into non-overlapping chunks of the target frames
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # slide a window of the target frames by the frame stride
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # select frame_sample start positions uniformly
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket resolution with frame count

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for operator in self.datasource:

            def fetch_and_resize(op: Callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
                video_key, video, caption = op()
                video: list[np.ndarray]
                frame_size = (video[0].shape[1], video[0].shape[0])

                # resize if necessary (the datasource may already have resized the frames)
                bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
                video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

                return frame_size, video_key, video, caption

            future = executor.submit(fetch_and_resize, operator)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()
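
    # Frame-extraction sketch (illustrative): for a 40-frame clip with
    # target_frames=[17], "head" yields crop positions [(0, 17)]; "chunk"
    # yields [(0, 17), (17, 17)]; "slide" with frame_stride=10 yields
    # [(0, 17), (10, 17), (20, 17)]; "uniform" picks frame_sample start
    # positions via np.linspace(0, 40 - 17, frame_sample).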

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob the latent cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cached items to buckets
        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # for example, "0852x0480"
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            frame_pos, frame_count = tokens[-3].split("-")  # for example, "00000-017"
            frame_pos, frame_count = int(frame_pos), int(frame_count)

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare the batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set the random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy length before prepare_for_training is called
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]


class DatasetGroup(torch.utils.data.ConcatDataset):
    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = list(datasets)
        self.num_train_items = 0
        for dataset in self.datasets:
            self.num_train_items += dataset.num_train_items

    def set_current_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def set_current_step(self, step):
        for dataset in self.datasets:
            dataset.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for dataset in self.datasets:
            dataset.set_max_train_steps(max_train_steps)
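

def _example_build_dataset_group():
    # Illustrative sketch; the paths and resolution below are assumptions.
    # prepare_for_training requires the latent and text encoder caches to
    # already exist in cache_directory.
    image_ds = ImageDataset(
        resolution=(960, 544),
        caption_extension=".txt",
        batch_size=1,
        enable_bucket=True,
        bucket_no_upscale=False,
        image_directory="/data/images",
        cache_directory="/data/cache",
    )
    image_ds.set_seed(42)
    image_ds.prepare_for_training()
    group = DatasetGroup([image_ds])
    print(group.num_train_items)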