|
|
|
|
|
import ast
|
|
import builtins
|
|
import collections.abc as abc
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from copy import deepcopy
|
|
from dataclasses import is_dataclass
|
|
from typing import List, Tuple, Union
|
|
import cloudpickle
|
|
import yaml
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf, SCMode
|
|
|
|
from detectron2.utils.file_io import PathManager
|
|
from detectron2.utils.registry import _convert_target_to_string
|
|
|
|
__all__ = ["LazyCall", "LazyConfig"]
|
|
|
|
|
|
class LazyCall:
|
|
"""
|
|
Wrap a callable so that when it's called, the call will not be executed,
|
|
but returns a dict that describes the call.
|
|
|
|
LazyCall object has to be called with only keyword arguments. Positional
|
|
arguments are not yet supported.
|
|
|
|
Examples:
|
|
::
|
|
from detectron2.config import instantiate, LazyCall
|
|
|
|
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
|
|
layer_cfg.out_channels = 64 # can edit it afterwards
|
|
layer = instantiate(layer_cfg)
|
|
"""
|
|
|
|
def __init__(self, target):
|
|
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
|
|
raise TypeError(
|
|
f"target of LazyCall must be a callable or defines a callable! Got {target}"
|
|
)
|
|
self._target = target
|
|
|
|
def __call__(self, **kwargs):
|
|
if is_dataclass(self._target):
|
|
|
|
|
|
target = _convert_target_to_string(self._target)
|
|
else:
|
|
target = self._target
|
|
kwargs["_target_"] = target
|
|
|
|
return DictConfig(content=kwargs, flags={"allow_objects": True})
|
|
|
|
|
|
def _visit_dict_config(cfg, func):
|
|
"""
|
|
Apply func recursively to all DictConfig in cfg.
|
|
"""
|
|
if isinstance(cfg, DictConfig):
|
|
func(cfg)
|
|
for v in cfg.values():
|
|
_visit_dict_config(v, func)
|
|
elif isinstance(cfg, ListConfig):
|
|
for v in cfg:
|
|
_visit_dict_config(v, func)
|
|
|
|
|
|
def _validate_py_syntax(filename):
|
|
|
|
with PathManager.open(filename, "r") as f:
|
|
content = f.read()
|
|
try:
|
|
ast.parse(content)
|
|
except SyntaxError as e:
|
|
raise SyntaxError(f"Config file {filename} has syntax error!") from e
|
|
|
|
|
|
def _cast_to_config(obj):
|
|
|
|
if isinstance(obj, dict):
|
|
return DictConfig(obj, flags={"allow_objects": True})
|
|
return obj
|
|
|
|
|
|
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
|
|
"""
|
|
A namespace to put all imported config into.
|
|
"""
|
|
|
|
|
|
def _random_package_name(filename):
|
|
|
|
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
|
|
|
|
|
|
@contextmanager
|
|
def _patch_import():
|
|
"""
|
|
Enhance relative import statements in config files, so that they:
|
|
1. locate files purely based on relative location, regardless of packages.
|
|
e.g. you can import file without having __init__
|
|
2. do not cache modules globally; modifications of module states has no side effect
|
|
3. support other storage system through PathManager, so config files can be in the cloud
|
|
4. imported dict are turned into omegaconf.DictConfig automatically
|
|
"""
|
|
old_import = builtins.__import__
|
|
|
|
def find_relative_file(original_file, relative_import_path, level):
|
|
|
|
|
|
|
|
relative_import_err = """
|
|
Relative import of directories is not allowed within config files.
|
|
Within a config file, relative import can only import other config files.
|
|
""".replace(
|
|
"\n", " "
|
|
)
|
|
if not len(relative_import_path):
|
|
raise ImportError(relative_import_err)
|
|
|
|
cur_file = os.path.dirname(original_file)
|
|
for _ in range(level - 1):
|
|
cur_file = os.path.dirname(cur_file)
|
|
cur_name = relative_import_path.lstrip(".")
|
|
for part in cur_name.split("."):
|
|
cur_file = os.path.join(cur_file, part)
|
|
if not cur_file.endswith(".py"):
|
|
cur_file += ".py"
|
|
if not PathManager.isfile(cur_file):
|
|
cur_file_no_suffix = cur_file[: -len(".py")]
|
|
if PathManager.isdir(cur_file_no_suffix):
|
|
raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
|
|
else:
|
|
raise ImportError(
|
|
f"Cannot import name {relative_import_path} from "
|
|
f"{original_file}: {cur_file} does not exist."
|
|
)
|
|
return cur_file
|
|
|
|
def new_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|
if (
|
|
|
|
level != 0
|
|
and globals is not None
|
|
and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
|
|
):
|
|
cur_file = find_relative_file(globals["__file__"], name, level)
|
|
_validate_py_syntax(cur_file)
|
|
spec = importlib.machinery.ModuleSpec(
|
|
_random_package_name(cur_file), None, origin=cur_file
|
|
)
|
|
module = importlib.util.module_from_spec(spec)
|
|
module.__file__ = cur_file
|
|
with PathManager.open(cur_file) as f:
|
|
content = f.read()
|
|
exec(compile(content, cur_file, "exec"), module.__dict__)
|
|
for name in fromlist:
|
|
val = _cast_to_config(module.__dict__[name])
|
|
module.__dict__[name] = val
|
|
return module
|
|
return old_import(name, globals, locals, fromlist=fromlist, level=level)
|
|
|
|
builtins.__import__ = new_import
|
|
yield new_import
|
|
builtins.__import__ = old_import
|
|
|
|
|
|
class LazyConfig:
|
|
"""
|
|
Provide methods to save, load, and overrides an omegaconf config object
|
|
which may contain definition of lazily-constructed objects.
|
|
"""
|
|
|
|
@staticmethod
|
|
def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
|
|
"""
|
|
Similar to :meth:`load()`, but load path relative to the caller's
|
|
source file.
|
|
|
|
This has the same functionality as a relative import, except that this method
|
|
accepts filename as a string, so more characters are allowed in the filename.
|
|
"""
|
|
caller_frame = inspect.stack()[1]
|
|
caller_fname = caller_frame[0].f_code.co_filename
|
|
assert caller_fname != "<string>", "load_rel Unable to find caller"
|
|
caller_dir = os.path.dirname(caller_fname)
|
|
filename = os.path.join(caller_dir, filename)
|
|
return LazyConfig.load(filename, keys)
|
|
|
|
@staticmethod
|
|
def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
|
|
"""
|
|
Load a config file.
|
|
|
|
Args:
|
|
filename: absolute path or relative path w.r.t. the current working directory
|
|
keys: keys to load and return. If not given, return all keys
|
|
(whose values are config objects) in a dict.
|
|
"""
|
|
has_keys = keys is not None
|
|
filename = filename.replace("/./", "/")
|
|
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
|
|
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
|
|
if filename.endswith(".py"):
|
|
_validate_py_syntax(filename)
|
|
|
|
with _patch_import():
|
|
|
|
module_namespace = {
|
|
"__file__": filename,
|
|
"__package__": _random_package_name(filename),
|
|
}
|
|
with PathManager.open(filename) as f:
|
|
content = f.read()
|
|
|
|
|
|
|
|
exec(compile(content, filename, "exec"), module_namespace)
|
|
|
|
ret = module_namespace
|
|
else:
|
|
with PathManager.open(filename) as f:
|
|
obj = yaml.unsafe_load(f)
|
|
ret = OmegaConf.create(obj, flags={"allow_objects": True})
|
|
|
|
if has_keys:
|
|
if isinstance(keys, str):
|
|
return _cast_to_config(ret[keys])
|
|
else:
|
|
return tuple(_cast_to_config(ret[a]) for a in keys)
|
|
else:
|
|
if filename.endswith(".py"):
|
|
|
|
ret = DictConfig(
|
|
{
|
|
name: _cast_to_config(value)
|
|
for name, value in ret.items()
|
|
if isinstance(value, (DictConfig, ListConfig, dict))
|
|
and not name.startswith("_")
|
|
},
|
|
flags={"allow_objects": True},
|
|
)
|
|
return ret
|
|
|
|
@staticmethod
|
|
def save(cfg, filename: str):
|
|
"""
|
|
Save a config object to a yaml file.
|
|
Note that when the config dictionary contains complex objects (e.g. lambda),
|
|
it can't be saved to yaml. In that case we will print an error and
|
|
attempt to save to a pkl file instead.
|
|
|
|
Args:
|
|
cfg: an omegaconf config object
|
|
filename: yaml file name to save the config file
|
|
"""
|
|
logger = logging.getLogger(__name__)
|
|
try:
|
|
cfg = deepcopy(cfg)
|
|
except Exception:
|
|
pass
|
|
else:
|
|
|
|
def _replace_type_by_name(x):
|
|
if "_target_" in x and callable(x._target_):
|
|
try:
|
|
x._target_ = _convert_target_to_string(x._target_)
|
|
except AttributeError:
|
|
pass
|
|
|
|
|
|
_visit_dict_config(cfg, _replace_type_by_name)
|
|
|
|
save_pkl = False
|
|
try:
|
|
dict = OmegaConf.to_container(
|
|
cfg,
|
|
|
|
|
|
resolve=False,
|
|
|
|
|
|
structured_config_mode=SCMode.INSTANTIATE,
|
|
)
|
|
dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999)
|
|
with PathManager.open(filename, "w") as f:
|
|
f.write(dumped)
|
|
|
|
try:
|
|
_ = yaml.unsafe_load(dumped)
|
|
except Exception:
|
|
logger.warning(
|
|
"The config contains objects that cannot serialize to a valid yaml. "
|
|
f"{filename} is human-readable but cannot be loaded."
|
|
)
|
|
save_pkl = True
|
|
except Exception:
|
|
logger.exception("Unable to serialize the config to yaml. Error:")
|
|
save_pkl = True
|
|
|
|
if save_pkl:
|
|
new_filename = filename + ".pkl"
|
|
try:
|
|
|
|
with PathManager.open(new_filename, "wb") as f:
|
|
cloudpickle.dump(cfg, f)
|
|
logger.warning(f"Config is saved using cloudpickle at {new_filename}.")
|
|
except Exception:
|
|
pass
|
|
|
|
@staticmethod
|
|
def apply_overrides(cfg, overrides: List[str]):
|
|
"""
|
|
In-place override contents of cfg.
|
|
|
|
Args:
|
|
cfg: an omegaconf config object
|
|
overrides: list of strings in the format of "a=b" to override configs.
|
|
See https://hydra.cc/docs/next/advanced/override_grammar/basic/
|
|
for syntax.
|
|
|
|
Returns:
|
|
the cfg object
|
|
"""
|
|
|
|
def safe_update(cfg, key, value):
|
|
parts = key.split(".")
|
|
for idx in range(1, len(parts)):
|
|
prefix = ".".join(parts[:idx])
|
|
v = OmegaConf.select(cfg, prefix, default=None)
|
|
if v is None:
|
|
break
|
|
if not OmegaConf.is_config(v):
|
|
raise KeyError(
|
|
f"Trying to update key {key}, but {prefix} "
|
|
f"is not a config, but has type {type(v)}."
|
|
)
|
|
OmegaConf.update(cfg, key, value, merge=True)
|
|
|
|
try:
|
|
from hydra.core.override_parser.overrides_parser import OverridesParser
|
|
|
|
has_hydra = True
|
|
except ImportError:
|
|
has_hydra = False
|
|
|
|
if has_hydra:
|
|
parser = OverridesParser.create()
|
|
overrides = parser.parse_overrides(overrides)
|
|
for o in overrides:
|
|
key = o.key_or_group
|
|
value = o.value()
|
|
if o.is_delete():
|
|
|
|
raise NotImplementedError("deletion is not yet a supported override")
|
|
safe_update(cfg, key, value)
|
|
else:
|
|
|
|
for o in overrides:
|
|
key, value = o.split("=")
|
|
try:
|
|
value = eval(value, {})
|
|
except NameError:
|
|
pass
|
|
safe_update(cfg, key, value)
|
|
return cfg
|
|
|
|
@staticmethod
|
|
def to_py(cfg, prefix: str = "cfg."):
|
|
"""
|
|
Try to convert a config object into Python-like psuedo code.
|
|
|
|
Note that perfect conversion is not always possible. So the returned
|
|
results are mainly meant to be human-readable, and not meant to be executed.
|
|
|
|
Args:
|
|
cfg: an omegaconf config object
|
|
prefix: root name for the resulting code (default: "cfg.")
|
|
|
|
|
|
Returns:
|
|
str of formatted Python code
|
|
"""
|
|
import black
|
|
|
|
cfg = OmegaConf.to_container(cfg, resolve=True)
|
|
|
|
def _to_str(obj, prefix=None, inside_call=False):
|
|
if prefix is None:
|
|
prefix = []
|
|
if isinstance(obj, abc.Mapping) and "_target_" in obj:
|
|
|
|
target = _convert_target_to_string(obj.pop("_target_"))
|
|
args = []
|
|
for k, v in sorted(obj.items()):
|
|
args.append(f"{k}={_to_str(v, inside_call=True)}")
|
|
args = ", ".join(args)
|
|
call = f"{target}({args})"
|
|
return "".join(prefix) + call
|
|
elif isinstance(obj, abc.Mapping) and not inside_call:
|
|
|
|
|
|
key_list = []
|
|
for k, v in sorted(obj.items()):
|
|
if isinstance(v, abc.Mapping) and "_target_" not in v:
|
|
key_list.append(_to_str(v, prefix=prefix + [k + "."]))
|
|
else:
|
|
key = "".join(prefix) + k
|
|
key_list.append(f"{key}={_to_str(v)}")
|
|
return "\n".join(key_list)
|
|
elif isinstance(obj, abc.Mapping):
|
|
|
|
return (
|
|
"{"
|
|
+ ",".join(
|
|
f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
|
|
for k, v in sorted(obj.items())
|
|
)
|
|
+ "}"
|
|
)
|
|
elif isinstance(obj, list):
|
|
return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]"
|
|
else:
|
|
return repr(obj)
|
|
|
|
py_str = _to_str(cfg, prefix=[prefix])
|
|
try:
|
|
return black.format_str(py_str, mode=black.Mode())
|
|
except black.InvalidInput:
|
|
return py_str
|
|
|