import argparse
from dataclasses import (
    asdict,
    dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path

# from toolz import curry
from typing import Dict, List, Optional, Sequence, Tuple, Union

import toml
import voluptuous
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema

from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass
class BaseDatasetParams:
    resolution: Tuple[int, int] = (960, 544)
    enable_bucket: bool = False
    bucket_no_upscale: bool = False
    caption_extension: Optional[str] = None
    batch_size: int = 1
    cache_directory: Optional[str] = None
    debug_dataset: bool = False


@dataclass
class ImageDatasetParams(BaseDatasetParams):
    image_directory: Optional[str] = None
    image_jsonl_file: Optional[str] = None


@dataclass
class VideoDatasetParams(BaseDatasetParams):
    video_directory: Optional[str] = None
    video_jsonl_file: Optional[str] = None
    target_frames: Sequence[int] = (1,)
    frame_extraction: Optional[str] = "head"
    frame_stride: Optional[int] = 1
    frame_sample: Optional[int] = 1


@dataclass
class DatasetBlueprint:
    is_image_dataset: bool
    params: Union[ImageDatasetParams, VideoDatasetParams]


@dataclass
class DatasetGroupBlueprint:
    datasets: Sequence[DatasetBlueprint]


@dataclass
class Blueprint:
    dataset_group: DatasetGroupBlueprint
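# Illustrative sketch of a user config accepted by ConfigSanitizer below (the paths and
# values are placeholders, not part of this module). The top level is an optional
# "general" table plus a list of "datasets"; an entry is treated as a video dataset
# when it contains "target_frames", otherwise as an image dataset.
#
#   [general]
#   resolution = [960, 544]
#   caption_extension = ".txt"
#   batch_size = 1
#   enable_bucket = true
#
#   [[datasets]]
#   image_directory = "/path/to/images"        # image dataset: no target_frames
#   cache_directory = "/path/to/image_cache"
#
#   [[datasets]]
#   video_directory = "/path/to/videos"        # video dataset: has target_frames
#   target_frames = [1, 25, 45]
#   frame_extraction = "head"
#   cache_directory = "/path/to/video_cache"
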
class ConfigSanitizer:
    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except Exception:
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "batch_size": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "enable_bucket": bool,
        "bucket_no_upscale": bool,
    }
    IMAGE_DATASET_DISTINCT_SCHEMA = {
        "image_directory": str,
        "image_jsonl_file": str,
        "cache_directory": str,
    }
    VIDEO_DATASET_DISTINCT_SCHEMA = {
        "video_directory": str,
        "video_jsonl_file": str,
        "target_frames": [int],
        "frame_extraction": str,
        "frame_stride": int,
        "frame_sample": int,
        "cache_directory": str,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
    }

    def __init__(self) -> None:
        self.image_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.IMAGE_DATASET_DISTINCT_SCHEMA,
        )
        self.video_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.VIDEO_DATASET_DISTINCT_SCHEMA,
        )

        def validate_flex_dataset(dataset_config: dict):
            # a dataset entry with "target_frames" is validated as a video dataset,
            # otherwise as an image dataset
            if "target_frames" in dataset_config:
                return Schema(self.video_dataset_schema)(dataset_config)
            else:
                return Schema(self.image_dataset_schema)(dataset_config)

        self.dataset_schema = validate_flex_dataset

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
        )
        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )
        self.argparse_schema = self.__merge_dict(
            self.ARGPARSE_SPECIFIC_SCHEMA,
        )
        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: clarify the error message
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: by nature, the argparse result does not need to be sanitized,
    # but doing so helps us detect program bugs
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: a value is overwritten by a later dict when the same key appears more than once
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged
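# BlueprintGenerator.generate_params_by_fallbacks (below) fills each dataclass field from
# the first source that provides a non-None value, searched in this order:
#   dataset entry -> [general] section -> argparse namespace -> runtime params,
# falling back to the dataclass default otherwise. For example, a dataset entry that omits
# "resolution" while [general] sets resolution = [960, 544] resolves to (960, 544).
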
class BlueprintGenerator:
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which are only configurable at runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            is_image_dataset = "target_frames" not in dataset_config
            if is_image_dataset:
                dataset_params_klass = ImageDatasetParams
            else:
                dataset_params_klass = VideoDatasetParams

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value


# if training is True, returns a dataset group for training; otherwise one for caching
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
    datasets: List[Union[ImageDataset, VideoDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_image_dataset:
            dataset_klass = ImageDataset
        else:
            dataset_klass = VideoDataset

        dataset = dataset_klass(**asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_image_dataset = isinstance(dataset, ImageDataset)
        info += dedent(
            f"""\
      [Dataset {i}]
        is_image_dataset: {is_image_dataset}
        resolution: {dataset.resolution}
        batch_size: {dataset.batch_size}
        caption_extension: "{dataset.caption_extension}"
        enable_bucket: {dataset.enable_bucket}
        bucket_no_upscale: {dataset.bucket_no_upscale}
        cache_directory: "{dataset.cache_directory}"
        debug_dataset: {dataset.debug_dataset}
    """
        )

        if is_image_dataset:
            info += indent(
                dedent(
                    f"""\
        image_directory: "{dataset.image_directory}"
        image_jsonl_file: "{dataset.image_jsonl_file}"
    \n"""
                ),
                "    ",
            )
        else:
            info += indent(
                dedent(
                    f"""\
        video_directory: "{dataset.video_directory}"
        video_jsonl_file: "{dataset.video_jsonl_file}"
        target_frames: {dataset.target_frames}
        frame_extraction: {dataset.frame_extraction}
        frame_stride: {dataset.frame_stride}
        frame_sample: {dataset.frame_sample}
    \n"""
                ),
                "    ",
            )
    logger.info(f"{info}")

    # make buckets first because that determines the length of the dataset,
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        # logger.info(f"[Dataset {i}]")
        dataset.set_seed(seed)
        if training:
            dataset.prepare_for_training()

    return DatasetGroup(datasets)


def load_user_config(file: str) -> dict:
    file: Path = Path(file)
    if not file.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {file}")

    if file.name.lower().endswith(".json"):
        try:
            with open(file, "r") as f:
                config = json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    elif file.name.lower().endswith(".toml"):
        try:
            config = toml.load(file)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")

    return config


# for config test
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug_dataset", action="store_true")
    argparse_namespace = parser.parse_args(remain)

    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    sanitizer = ConfigSanitizer()
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")

    dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)