|
import logging |
|
import os |
|
import warnings |
|
from pathlib import Path |
|
from typing import List |
|
|
|
from dotenv import find_dotenv, load_dotenv |
|
from yacs.config import CfgNode |
|
from yacs.config import CfgNode as CN |
|
from yacs.config import _assert_with_logging, _valid_type |
|
|
|
from config import _C |
|
|
|
|
|
def check_cfg(C): |
|
pass |
|
|
|
|
|
def get_cfg_defaults(): |
|
""" |
|
Get a yacs CfgNode object with default values |
|
""" |
|
|
|
|
|
return _C.clone() |
|
|
|
def save_to_yaml(cfg, path_output): |
|
""" |
|
Save the current config to a YAML file. |
|
:param cfg: CfgNode object to be saved |
|
:param path_output: path to output files |
|
""" |
|
path_output = Path(path_output) |
|
path_output.parent.mkdir(parents=True, exist_ok=True) |
|
with open(path_output, "w") as f: |
|
f.write(cfg.dump()) |
|
|
|
def load_from_yaml(path_cfg_data, path_cfg_override=None, list_cfg_override=None): |
|
""" |
|
Load a config from a YAML file. |
|
:param path_cfg_data: path to path_cfg_data files |
|
:param path_cfg_override: path to path_cfg_override actual |
|
:param list_cfg_override: [key1, value1, key2, value2, ...] |
|
:return: cfg_base incorporating the overwrite. |
|
""" |
|
cfg_base = get_cfg_defaults() |
|
if path_cfg_data is not None: |
|
cfg_base.merge_from_file(path_cfg_data) |
|
if path_cfg_override is not None: |
|
cfg_base.merge_from_file(path_cfg_override) |
|
if list_cfg_override is not None: |
|
cfg_base.merge_from_list(list_cfg_override) |
|
return cfg_base |
|
|
|
def convert_to_dict(cfg_node): |
|
def _convert_to_dict(cfg_node, key_list): |
|
_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)} |
|
if not isinstance(cfg_node, CfgNode): |
|
_assert_with_logging( |
|
_valid_type(cfg_node), |
|
"Key {} with value {} is not a valid type; valid types: {}".format( |
|
".".join(key_list), type(cfg_node), _VALID_TYPES |
|
), |
|
) |
|
return cfg_node |
|
else: |
|
cfg_dict = dict(cfg_node) |
|
for k, v in cfg_dict.items(): |
|
cfg_dict[k] = _convert_to_dict(v, key_list + [k]) |
|
return cfg_dict |
|
|
|
return _convert_to_dict(cfg_node, []) |
|
|
|
|
|
def combine_cfgs( |
|
path_cfg_data: Path = None, |
|
path_cfg_override: Path = None, |
|
list_cfg_override: List = None, |
|
): |
|
""" |
|
An internal facing routine thaat combined CFG in the order provided. |
|
:param path_output: path to output files |
|
:param path_cfg_data: path to path_cfg_data files |
|
:param path_cfg_override: path to path_cfg_override actual |
|
:param list_cfg_override: [key1, value1, key2, value2, ...] |
|
:return: cfg_base incorporating the overwrite. |
|
""" |
|
if path_cfg_data is not None: |
|
path_cfg_data = Path(path_cfg_data) |
|
if path_cfg_override is not None: |
|
path_cfg_override = Path(path_cfg_override) |
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg_base = get_cfg_defaults() |
|
|
|
|
|
|
|
if path_cfg_data is not None and path_cfg_data.exists(): |
|
cfg_base.merge_from_file(path_cfg_data.absolute()) |
|
|
|
|
|
|
|
if path_cfg_override is not None and path_cfg_override.exists(): |
|
cfg_base.merge_from_file(path_cfg_override.absolute()) |
|
|
|
|
|
|
|
if list_cfg_override is not None: |
|
cfg_base.merge_from_list(list_cfg_override) |
|
|
|
|
|
|
|
list_cfg = update_cfg_using_dotenv() |
|
if list_cfg is not []: |
|
cfg_base.merge_from_list(list_cfg) |
|
|
|
check_cfg(cfg_base) |
|
|
|
return cfg_base |
|
|
|
|
|
def update_cfg_using_dotenv() -> list: |
|
""" |
|
In case when there are dotenvs, try to return list of them. |
|
# It is returning a list of hard overwrite. |
|
:return: empty list or overwriting information |
|
""" |
|
|
|
if find_dotenv() == "": |
|
warnings.warn(".env files not found. YACS config file merging aborted.") |
|
return [] |
|
|
|
|
|
load_dotenv(find_dotenv(), verbose=True) |
|
|
|
|
|
list_key_env = { |
|
"DATASET.ROOT_DIR", |
|
"DATASET.VOXEL_INDEX_DIR", |
|
"MODEL.BACKBONE.PRETRAINED_WEIGHT_DIR", |
|
"TRAINER.CALLBACKS.CHECKPOINT.ROOT_DIR", |
|
"RESULTS_DIR", |
|
} |
|
|
|
|
|
path_overwrite_keys = [] |
|
logging.info("merge from .env") |
|
|
|
for key in list_key_env: |
|
|
|
|
|
value = os.getenv(key) |
|
logging.info(f"{key}={value}") |
|
|
|
if value is None: |
|
continue |
|
|
|
|
|
path_overwrite_keys.append(key) |
|
path_overwrite_keys.append(value) |
|
|
|
return path_overwrite_keys |
|
|
|
|
|
|
|
|
|
|
|
def flatten_dict(dictionary, level=[]): |
|
"""Flattens a dictionary by placing '.' between levels. |
|
|
|
This function flattens a hierarchical dictionary by placing '.' |
|
between keys at various levels to create a single key for each |
|
value. It is used internally for converting the configuration |
|
dictionary to more convenient formats. Implementation was |
|
inspired by `this StackOverflow post |
|
<https://stackoverflow.com/questions/6037503/python-unflatten-dict>`_. |
|
|
|
Parameters |
|
---------- |
|
dictionary : dict |
|
The hierarchical dictionary to be flattened. |
|
level : str, optional |
|
The string to append to the beginning of this dictionary, |
|
enabling recursive calls. By default, an empty string. |
|
|
|
Returns |
|
------- |
|
dict |
|
The flattened dictionary. |
|
|
|
See Also |
|
-------- |
|
lfads_tf2.utils.unflatten : Performs the opposite of this operation. |
|
|
|
""" |
|
|
|
tmp_dict = {} |
|
for key, val in dictionary.items(): |
|
if type(val) == dict: |
|
tmp_dict.update(flatten_dict(val, level + [key])) |
|
else: |
|
tmp_dict[".".join(level + [key])] = val |
|
return tmp_dict |
|
|
|
|
|
def unflatten(dictionary): |
|
"""Unflattens a dictionary by splitting keys at '.'s. |
|
|
|
This function unflattens a hierarchical dictionary by splitting |
|
its keys at '.'s. It is used internally for converting the |
|
configuration dictionary to more convenient formats. Implementation was |
|
inspired by `this StackOverflow post |
|
<https://stackoverflow.com/questions/6037503/python-unflatten-dict>`_. |
|
|
|
Parameters |
|
---------- |
|
dictionary : dict |
|
The flat dictionary to be unflattened. |
|
|
|
Returns |
|
------- |
|
dict |
|
The unflattened dictionary. |
|
|
|
See Also |
|
-------- |
|
lfads_tf2.utils.flatten : Performs the opposite of this operation. |
|
|
|
""" |
|
|
|
resultDict = dict() |
|
for key, value in dictionary.items(): |
|
parts = key.split(".") |
|
d = resultDict |
|
for part in parts[:-1]: |
|
if part not in d: |
|
d[part] = dict() |
|
d = d[part] |
|
d[parts[-1]] = value |
|
return resultDict |
|
|
|
|
|
def dict_to_list(config): |
|
config_list = [] |
|
for key, val in config.items(): |
|
|
|
config_list.append(key) |
|
config_list.append(val) |
|
return config_list |
|
|