DiffusionText2WorldGeneration / config_helper.py
EthanZyh's picture
copied from EthanZyh/DiffusionText2WorldGeneration
8c31d70
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import pkgutil
import sys
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import Any, Dict, Optional
import attr
import attrs
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf
from .log import log
from .config import Config
from .inference import *
def is_attrs_or_dataclass(obj) -> bool:
"""
Check if the object is an instance of an attrs class or a dataclass.
Args:
obj: The object to check.
Returns:
bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
"""
return is_dataclass(obj) or attr.has(type(obj))
def get_fields(obj):
"""
Get the fields of an attrs class or a dataclass.
Args:
obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
Returns:
list: A list of field names.
Raises:
ValueError: If the object is neither an attrs class nor a dataclass.
"""
if is_dataclass(obj):
return [field.name for field in dataclass_fields(obj)]
elif attr.has(type(obj)):
return [field.name for field in attr.fields(type(obj))]
else:
raise ValueError("The object is neither an attrs class nor a dataclass.")
def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
"""
:param config: the instance of class `Config` (usually from `make_config`)
:param overrides: list of overrides for config
:return: the composed instance of class `Config`
"""
# Store the class of the config for reconstruction after overriding.
# config_class = type(config)
# Convert Config object to a DictConfig object
config_dict = attrs.asdict(config)
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
# Enforce "--" separator between the script arguments and overriding configs.
if overrides:
if overrides[0] != "--":
raise ValueError('Hydra config overrides must be separated with a "--" token.')
overrides = overrides[1:]
# Use Hydra to handle overrides
cs = ConfigStore.instance()
cs.store(name="config", node=config_omegaconf)
with initialize(version_base=None):
config_omegaconf = compose(config_name="config", overrides=overrides)
OmegaConf.resolve(config_omegaconf)
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
"""
Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data
Args:
ref_instance: The reference instance to determine the type and fields when needed
kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
Returns:
Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
Raises:
AssertionError: If the fields do not match or if extra keys are found.
Exception: If there is an error constructing the new instance.
"""
is_type = is_attrs_or_dataclass(ref_instance)
if not is_type:
return kwargs
else:
ref_fields = set(get_fields(ref_instance))
assert isinstance(kwargs, dict) or isinstance(
kwargs, DictConfig
), "kwargs must be a dictionary or a DictConfig"
keys = set(kwargs.keys())
# ref_fields must equal to or include all keys
extra_keys = keys - ref_fields
assert ref_fields == keys or keys.issubset(
ref_fields
), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
resolved_kwargs: Dict[str, Any] = {}
for f in keys:
resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
try:
new_instance = type(ref_instance)(**resolved_kwargs)
except Exception as e:
log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
log.error(e)
raise e
return new_instance
config = config_from_dict(config, config_omegaconf)
return config
def get_config_module(config_file: str) -> str:
if not config_file.endswith(".py"):
log.error("Config file cannot be specified as module.")
log.error("Please provide the path to the Python config file (relative to the Cosmos root).")
assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found."
# Convert to importable module format.
config_module = config_file.replace("/", ".").replace(".py", "")
return config_module
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
"""
Import all modules from the specified package path recursively.
This function is typically used in conjunction with Hydra to ensure that all modules
within a specified package are imported, which is necessary for registering configurations.
Example usage:
```python
import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
```
Args:
package_path (str): The dotted path to the package from which to import all modules.
reload (bool): Flag to determine whether to reload modules if they're already imported.
skip_underscore (bool): If True, skips importing modules that start with an underscore.
"""
return # we do not use this function
log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
package = importlib.import_module(package_path)
package_directory = package.__path__
def import_modules_recursively(directory: str, prefix: str) -> None:
"""
Recursively imports or reloads all modules in the given directory.
Args:
directory (str): The file system path to the current package directory.
prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config').
"""
for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
if skip_underscore and module_name.startswith("_"):
log.debug(f"Skipping module {module_name} as it starts with an underscore")
continue
full_module_name = f"{prefix}.{module_name}"
log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
if full_module_name in sys.modules and reload:
importlib.reload(sys.modules[full_module_name])
else:
importlib.import_module(full_module_name)
if is_pkg:
sub_package_directory = os.path.join(directory, module_name)
import_modules_recursively(sub_package_directory, full_module_name)
for directory in package_directory:
import_modules_recursively(directory, package_path)