# (Page residue captured with this file) Spaces: Build error
import os
from typing import Dict

from yacs.config import CfgNode as CN
# Directory that relative asset paths are joined onto when get_config() is
# called with update_cachedir=True.
CACHE_DIR_HAMER = "./_DATA"
def to_lower(x: Dict) -> Dict:
    """
    Convert all dictionary keys to lowercase.

    Only the top-level keys are converted; values (including nested
    dictionaries) are passed through untouched. Keys that differ only by
    case collapse into a single entry, with the value seen last in
    iteration order winning. Every key must provide a ``lower()`` method
    (i.e. be a string), otherwise ``AttributeError`` is raised.

    Args:
        x (dict): Input dictionary
    Returns:
        dict: New dictionary with all keys converted to lowercase
    """
    return {key.lower(): value for key, value in x.items()}
# Default configuration tree. Every node is created with new_allowed=True so
# that config files merged on top of these defaults may introduce keys that
# are not declared here. The meaning and units of individual values are
# defined by the code that consumes them, which is not part of this module.
_C = CN(new_allowed=True)

# General run settings (resuming, logging/checkpoint cadence, devices).
_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
_C.GENERAL.TIME_TO_RUN = 3300
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1

# Training loop settings.
_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0

# Declared empty here; expected to be populated by merged config files.
_C.LOSS_WEIGHTS = CN(new_allowed=True)

_C.DATASETS = CN(new_allowed=True)

_C.MODEL = CN(new_allowed=True)
_C.MODEL.IMAGE_SIZE = 224

_C.EXTRA = CN(new_allowed=True)
_C.EXTRA.FOCAL_LENGTH = 5000

# Data augmentation defaults shared by the datasets.
_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = False
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
def default_config() -> CN:
    """
    Get a yacs CfgNode object with the default config values.

    Returns:
        CfgNode: A clone of the module-level defaults. Callers may modify
        the returned object freely without affecting the defaults or the
        configs handed out to other callers.
    """
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
def dataset_config() -> CN:
    """
    Get dataset config file.

    Loads ``datasets_tar.yaml`` from the directory that contains this module
    (resolved through ``os.path.realpath``, so symlinks to the module are
    followed) into a fresh config node.

    Returns:
        CfgNode: Dataset config as a frozen (read-only) yacs CfgNode object.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file = os.path.join(module_dir, 'datasets_tar.yaml')

    # Start from an empty node that accepts any key found in the YAML file.
    cfg = CN(new_allowed=True)
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg
def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
    """
    Read a config file and optionally merge it with the default config file.

    Args:
        config_file (str): Path to config file.
        merge (bool): Whether to merge with the default config or not. When
            False, the file is loaded into an empty config node instead.
        update_cachedir (bool): If True, rewrite ``MANO.MODEL_PATH`` and
            ``MANO.MEAN_PARAMS`` so that relative paths are joined onto
            ``CACHE_DIR_HAMER``; absolute paths are left unchanged.
    Returns:
        CfgNode: Config as a frozen (read-only) yacs CfgNode object.
    Raises:
        AttributeError: If ``update_cachedir`` is True and the loaded config
            does not define ``MANO.MODEL_PATH`` / ``MANO.MEAN_PARAMS``.
    """
    if merge:
        cfg = default_config()
    else:
        cfg = CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    if update_cachedir:
        def update_path(path: str) -> str:
            # Absolute paths already point at a concrete location; only
            # relative ones are re-rooted under the cache directory.
            if os.path.isabs(path):
                return path
            return os.path.join(CACHE_DIR_HAMER, path)

        cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
        cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)

    # Freeze last: the path rewrite above must happen while the node is
    # still mutable.
    cfg.freeze()
    return cfg