HaMeR / hamer /configs /__init__.py
geopavlakos's picture
Initial commit
d7a991a
raw
history blame
3.13 kB
import os
from typing import Dict
from yacs.config import CfgNode as CN
# Root directory used by get_config(update_cachedir=True) to resolve relative
# MANO asset paths. The value is itself a relative path, so it is interpreted
# against the current working directory at file-open time.
CACHE_DIR_HAMER = "./_DATA"
def to_lower(x: Dict) -> Dict:
    """
    Return a new dictionary whose top-level keys are lowercased.

    Only the outermost keys are transformed. Values, including nested
    dictionaries, are carried over unchanged. If two keys collide after
    lowercasing, the one that comes later in iteration order wins.

    Args:
        x (dict): Mapping whose keys support ``.lower()`` (e.g. strings).
    Returns:
        dict: A new dictionary with lowercased keys; ``x`` is not modified.
    """
    lowered = {}
    for key, value in x.items():
        lowered[key.lower()] = value
    return lowered
# Module-level template holding the default config values. Every node is
# created with new_allowed=True, so config files merged on top of a clone of
# this template (see default_config / get_config) may add keys that are not
# declared here.
_C = CN(new_allowed=True)
# --- General run / infrastructure settings ---
_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
# NOTE(review): the unit of TIME_TO_RUN is not visible in this file — confirm
# with the consumer (training script).
_C.GENERAL.TIME_TO_RUN = 3300
# Presumably step intervals for validation / logging / checkpointing — confirm
# against the trainer that reads them.
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1
# --- Training loop settings ---
_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
# CLIP_GRAD_VALUE is presumably only consulted when CLIP_GRAD is True — confirm.
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0
# No default loss weights or dataset entries are declared here; presumably
# they come from the experiment YAML merged later — verify.
_C.LOSS_WEIGHTS = CN(new_allowed=True)
_C.DATASETS = CN(new_allowed=True)
_C.MODEL = CN(new_allowed=True)
# Assumed to be the (square) network input size in pixels — TODO confirm.
_C.MODEL.IMAGE_SIZE = 224
_C.EXTRA = CN(new_allowed=True)
# Assumed to be a camera focal length in pixels — TODO confirm.
_C.EXTRA.FOCAL_LENGTH = 5000
# --- Data augmentation defaults shared by datasets ---
# NOTE(review): *_RATE values look like probabilities in [0, 1] and ROT_FACTOR
# presumably is in degrees — confirm against the dataset code.
_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = False
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
def default_config() -> CN:
    """
    Return a fresh copy of the module-level default config.

    A clone of ``_C`` is handed out rather than ``_C`` itself. Callers can
    therefore modify the result (the "local variable" pattern) without
    altering the shared defaults.

    Returns:
        CfgNode: An independent copy of the default config values.
    """
    fresh_cfg = _C.clone()
    return fresh_cfg
def dataset_config() -> CN:
    """
    Load the dataset definitions file that sits next to this module.

    Reads ``datasets_tar.yaml`` from the directory containing this source
    file. The file is merged into an empty config that accepts arbitrary keys.

    Returns:
        CfgNode: The dataset config, frozen so it cannot be modified.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    yaml_path = os.path.join(module_dir, 'datasets_tar.yaml')

    datasets_cfg = CN(new_allowed=True)
    datasets_cfg.merge_from_file(yaml_path)
    datasets_cfg.freeze()
    return datasets_cfg
def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
    """
    Load a config file, optionally layering it on top of the defaults.

    Args:
        config_file (str): Path to the config file to read.
        merge (bool): If True, start from a copy of the default config and
            merge ``config_file`` into it. If False, start from an empty
            config that accepts any keys.
        update_cachedir (bool): If True, rewrite ``MANO.MODEL_PATH`` and
            ``MANO.MEAN_PARAMS`` so that relative paths are resolved under
            ``CACHE_DIR_HAMER``. Absolute paths are left untouched. In this
            case the loaded config must contain a ``MANO`` section with both
            keys.
    Returns:
        CfgNode: The loaded config, frozen.
    """
    cfg = default_config() if merge else CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    if update_cachedir:
        def resolve(path: str) -> str:
            # Absolute paths are kept as-is; relative ones live under the cache dir.
            return path if os.path.isabs(path) else os.path.join(CACHE_DIR_HAMER, path)

        # getattr/setattr keep CfgNode attribute semantics (AttributeError on
        # a missing key), matching plain ``cfg.MANO.X`` access.
        mano_cfg = cfg.MANO
        for field in ('MODEL_PATH', 'MEAN_PARAMS'):
            setattr(mano_cfg, field, resolve(getattr(mano_cfg, field)))

    cfg.freeze()
    return cfg