|
import os |
|
import yaml |
|
from yacs.config import CfgNode as CN |
|
|
|
def _build_default_config():
    """Create and return the root ``CfgNode`` holding every default option.

    Each section is filled through a short local alias so the tree reads
    top-down. Values may later be overridden by YAML files and command-line
    arguments (see ``update_config``).
    """
    root = CN()

    # YAML files this config inherits from; ``_update_config_from_file``
    # resolves them relative to the file that lists them. '' means "none".
    root.BASE = ['']

    # ------------------------------------------------------------------ data
    data = root.DATA = CN()
    data.BATCH_SIZE = 128
    data.DATA_PATHS = ['']
    data.DATASET = 'MODIS'
    data.IMG_SIZE = 224
    data.INTERPOLATION = 'bicubic'
    data.PIN_MEMORY = True
    data.NUM_WORKERS = 8
    # NOTE(review): presumably masking parameters for masked-image
    # pretraining — confirm against the data pipeline.
    data.MASK_PATCH_SIZE = 32
    data.MASK_RATIO = 0.6

    # ----------------------------------------------------------------- model
    model = root.MODEL = CN()
    model.TYPE = 'swinv2'
    model.DECODER = None
    # NAME is also used as a path component of OUTPUT in update_config.
    # NOTE(review): the name says "base", but the SWINV2 sizes below
    # (EMBED_DIM 96, DEPTHS [2, 2, 6, 2]) look like the Swin "tiny" preset —
    # confirm which one is intended.
    model.NAME = 'swinv2_base_patch4_window7_224'
    model.PRETRAINED = ''
    model.RESUME = ''
    model.NUM_CLASSES = 17
    model.DROP_RATE = 0.0
    model.DROP_PATH_RATE = 0.1

    swin = model.SWINV2 = CN()
    swin.PATCH_SIZE = 4
    swin.IN_CHANS = 3
    swin.EMBED_DIM = 96
    swin.DEPTHS = [2, 2, 6, 2]
    swin.NUM_HEADS = [3, 6, 12, 24]
    swin.WINDOW_SIZE = 7
    swin.MLP_RATIO = 4.0
    swin.QKV_BIAS = True
    swin.APE = False
    swin.PATCH_NORM = True
    swin.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

    # ------------------------------------------------------------------ loss
    loss = root.LOSS = CN()
    loss.NAME = 'tversky'
    loss.MODE = 'multiclass'
    loss.CLASSES = None
    loss.LOG = False
    loss.LOGITS = True
    loss.SMOOTH = 0.0
    loss.IGNORE_INDEX = None
    loss.EPS = 1e-7
    loss.ALPHA = 0.5
    loss.BETA = 0.5
    loss.GAMMA = 1.0

    # ----------------------------------------------------------------- train
    train = root.TRAIN = CN()
    train.START_EPOCH = 0
    train.EPOCHS = 300
    train.WARMUP_EPOCHS = 20
    train.WEIGHT_DECAY = 0.05
    train.BASE_LR = 5e-4
    train.WARMUP_LR = 5e-7
    train.MIN_LR = 5e-6
    train.CLIP_GRAD = 5.0
    train.AUTO_RESUME = True
    train.ACCUMULATION_STEPS = 0
    train.USE_CHECKPOINT = False

    sched = train.LR_SCHEDULER = CN()
    sched.NAME = 'cosine'
    sched.DECAY_EPOCHS = 30
    sched.DECAY_RATE = 0.1
    sched.GAMMA = 0.1
    sched.MULTISTEPS = []

    optim = train.OPTIMIZER = CN()
    optim.NAME = 'adamw'
    optim.EPS = 1e-8
    optim.BETAS = (0.9, 0.999)
    optim.MOMENTUM = 0.9

    train.LAYER_DECAY = 1.0

    # ------------------------------------------------------------------ test
    evaluation = root.TEST = CN()
    evaluation.CROP = True

    # ----------------------------------------------------------------- misc
    # NOTE(review): two independent AMP switches exist. update_config sets
    # ENABLE_AMP from --enable_amp and clears AMP_ENABLE on --disable_amp;
    # their defaults disagree. Consider consolidating them.
    root.ENABLE_AMP = False
    root.AMP_ENABLE = True
    root.OUTPUT = ''
    root.TAG = 'pt-caney-default-tag'
    root.SAVE_FREQ = 1
    root.PRINT_FREQ = 10
    root.SEED = 42
    root.EVAL_MODE = False

    return root


_C = _build_default_config()
|
|
|
|
|
def _update_config_from_file(config, cfg_file):
    """Merge a YAML config file, and the files it inherits from, into ``config``.

    Files listed under the file's top-level ``BASE`` key are resolved
    relative to the directory of ``cfg_file``. They are merged first,
    recursively, so values in ``cfg_file`` override inherited ones.

    Args:
        config: yacs ``CfgNode`` updated in place; it is frozen on return.
        cfg_file: Path of the YAML file to merge.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        # safe_load matches the loader yacs uses inside merge_from_file.
        # An empty file loads as None, so treat that as an empty mapping.
        yaml_cfg = yaml.safe_load(f) or {}

    # A missing or explicit-null BASE both mean "no parent files".
    for base_cfg in yaml_cfg.get('BASE') or ['']:
        if base_cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), base_cfg)
            )

    # The recursive call above freezes ``config``; reopen it before merging.
    config.defrost()
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
|
|
|
|
|
def update_config(config, args):
    """Apply a YAML config file and command-line overrides to ``config``.

    ``config`` is modified in place in three steps:

    1. If ``args.cfg`` is set, that YAML file (and its ``BASE`` parents) is
       merged.
    2. Each recognised attribute of ``args`` that is present and truthy
       overrides the matching option.
    3. ``OUTPUT`` is expanded to ``OUTPUT/MODEL.NAME/TAG`` and the config is
       frozen.

    Args:
        config: yacs ``CfgNode`` with default values (e.g. a clone of ``_C``).
        args: Parsed command-line namespace. Any attribute may be absent;
            absent or falsy values leave the corresponding option unchanged.
    """
    # Only merge a file when one was given; otherwise keep the defaults.
    cfg_file = getattr(args, 'cfg', None)
    if cfg_file:
        _update_config_from_file(config, cfg_file)

    config.defrost()

    def _check_args(name):
        # True when ``args`` has attribute ``name`` and its value is truthy.
        # getattr gives the same test as the old eval() without evaluating code.
        return bool(getattr(args, name, None))

    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_paths'):
        config.DATA.DATA_PATHS = args.data_paths
    if _check_args('dataset'):
        config.DATA.DATASET = args.dataset
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    # NOTE: two separate AMP switches exist. --disable_amp clears AMP_ENABLE
    # and --enable_amp sets ENABLE_AMP; they are not kept in sync here.
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp

    # Expand OUTPUT to <OUTPUT>/<MODEL.NAME>/<TAG>.
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
|
|
|
|
|
def get_config(args):
    """Build the run configuration for ``args``.

    Returns a copy of the module defaults with the YAML file and the
    command-line overrides from ``args`` applied by ``update_config``. The
    shared defaults object itself is never modified.
    """
    merged = _C.clone()
    update_config(merged, args)
    return merged
|
|