from functools import partial

from torch import optim as optim

import os

import torch
import torch.distributed as dist
import numpy as np
from scipy import interpolate
def build_optimizer(config, model, is_pretrain=False, logger=None):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    AdamW only.
    """
    logger.info('>>>>>>>>>> Build Optimizer')

    skip = {}
    skip_keywords = {}

    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()

    if is_pretrain:
        parameters = get_pretrain_param_groups(model, skip, skip_keywords)
    else:
        depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \
            else config.MODEL.SWINV2.DEPTHS
        num_layers = sum(depths)
        get_layer_func = partial(get_swin_layer,
                                 num_layers=num_layers + 2,
                                 depths=depths)
        scales = list(config.TRAIN.LAYER_DECAY ** i for i in
                      reversed(range(num_layers + 2)))
        parameters = get_finetune_param_groups(model,
                                               config.TRAIN.BASE_LR,
                                               config.TRAIN.WEIGHT_DECAY,
                                               get_layer_func,
                                               scales,
                                               skip,
                                               skip_keywords)

    optimizer = optim.AdamW(parameters,
                            eps=config.TRAIN.OPTIMIZER.EPS,
                            betas=config.TRAIN.OPTIMIZER.BETAS,
                            lr=config.TRAIN.BASE_LR,
                            weight_decay=config.TRAIN.WEIGHT_DECAY)
    logger.info(optimizer)

    return optimizer
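# Usage sketch (illustrative only): `config` is assumed to be the repo's
# YACS-style config and `logger` a standard logging.Logger.
#
#   optimizer = build_optimizer(config, model, is_pretrain=False,
#                               logger=logger)
#   for group in optimizer.param_groups:
#       logger.info(f"lr={group['lr']}, wd={group['weight_decay']}")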
def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """
    Split parameters into a weight-decay group and a no-weight-decay group.

    Args:
        model (torch.nn.Module): model whose parameters are split.
        skip_list (tuple, optional): parameter names excluded from weight
            decay. Defaults to ().
        skip_keywords (tuple, optional): substrings; parameters whose names
            contain any of them are excluded from weight decay.
            Defaults to ().

    Returns:
        list: two optimizer parameter groups, the second with weight_decay=0.
    """
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith(".bias") \
                or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
        else:
            has_decay.append(param)

    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]
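# The returned structure (shared by set_weight_decay and
# get_pretrain_param_groups) is a plain list of optimizer groups, e.g.:
#
#   [{'params': [<tensors with weight decay>]},
#    {'params': [<1-D params, biases, skipped names>], 'weight_decay': 0.}]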
def check_keywords_in_name(name, keywords=()):
    return any(keyword in name for keyword in keywords)
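# For example (hypothetical parameter name):
#   check_keywords_in_name('layers.0.blocks.1.attn.cpb_mlp.0.weight',
#                          ('cpb_mlp', 'logit_scale'))  -> True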
def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()):
    has_decay = []
    no_decay = []
    has_decay_name = []
    no_decay_name = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith(".bias") or \
                (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            no_decay_name.append(name)
        else:
            has_decay.append(param)
            has_decay_name.append(name)

    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]
def get_swin_layer(name, num_layers, depths):
    if name in ("mask_token",):
        return 0
    elif name.startswith("patch_embed"):
        return 0
    elif name.startswith("layers"):
        layer_id = int(name.split('.')[1])
        block_id = name.split('.')[3]
        if block_id == 'reduction' or block_id == 'norm':
            return sum(depths[:layer_id + 1])
        layer_id = sum(depths[:layer_id]) + int(block_id)
        return layer_id + 1
    else:
        return num_layers - 1
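# Illustration (assuming Swin-B depths (2, 2, 18, 2), so num_layers here is
# sum(depths) + 2 = 26):
#   mask_token / patch_embed.*         -> layer id 0
#   layers.0.blocks.0.*                -> layer id 1
#   layers.2.blocks.17.*               -> layer id 22
#   layers.2.downsample.reduction.*    -> layer id 22 (sum(depths[:3]))
#   head.*, norm.* (the `else` branch) -> layer id 25 (num_layers - 1)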
def get_finetune_param_groups(model,
                              lr,
                              weight_decay,
                              get_layer_func,
                              scales,
                              skip_list=(),
                              skip_keywords=()):
    parameter_group_names = {}
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith(".bias") \
                or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            group_name = "no_decay"
            this_weight_decay = 0.
        else:
            group_name = "decay"
            this_weight_decay = weight_decay

        if get_layer_func is not None:
            layer_id = get_layer_func(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)
        else:
            layer_id = None

        if group_name not in parameter_group_names:
            if scales is not None:
                scale = scales[layer_id]
            else:
                scale = 1.

            parameter_group_names[group_name] = {
                "group_name": group_name,
                "weight_decay": this_weight_decay,
                "params": [],
                "lr": lr * scale,
                "lr_scale": scale,
            }
            parameter_group_vars[group_name] = {
                "group_name": group_name,
                "weight_decay": this_weight_decay,
                "params": [],
                "lr": lr * scale,
                "lr_scale": scale
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    return list(parameter_group_vars.values())
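# With layer-wise LR decay, a parameter assigned layer id `i` out of
# L = sum(depths) + 2 ids ends up with lr = BASE_LR * LAYER_DECAY ** (L - 1 - i),
# so the head (highest id) keeps the full base learning rate while earlier
# layers are scaled down geometrically.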
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........")

    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    # Map 'rpe_mlp' keys (used by some checkpoints) to the 'cpb_mlp' naming
    # expected by the current model.
    rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint['model'][k.replace(
            'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k)

    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint \
            and 'lr_scheduler' in checkpoint \
            and 'scaler' in checkpoint \
            and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])

        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        logger.info(
            f"=> loaded successfully '{config.MODEL.RESUME}' " +
            f"(epoch {checkpoint['epoch']})")

        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
        else:
            max_accuracy = 0.0

    del checkpoint
    torch.cuda.empty_cache()

    return max_accuracy
def save_checkpoint(config, epoch, model, max_accuracy,
                    optimizer, lr_scheduler, scaler, logger):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'scaler': scaler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
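# This mirrors the total-norm computation used by
# torch.nn.utils.clip_grad_norm_, but only reports the norm and never
# rescales the gradients.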
def auto_resume_helper(output_dir, logger):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    logger.info(f"All checkpoints found in {output_dir}: {checkpoints}")

    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d)
                                 for d in checkpoints], key=os.path.getmtime)
        logger.info(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None

    return resume_file
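# Typical call before training (sketch; assumes the YACS-style frozen config
# used elsewhere in this module):
#
#   resume_file = auto_resume_helper(config.OUTPUT, logger)
#   if resume_file is not None:
#       config.defrost()
#       config.MODEL.RESUME = resume_file
#       config.freeze()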
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
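# Averages a tensor across all distributed ranks; requires that
# torch.distributed has been initialized (dist.init_process_group) before
# this is called.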
def load_pretrained(config, model, logger):
    logger.info(
        f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    checkpoint_model = checkpoint['model']

    if any('encoder.' in k for k in checkpoint_model.keys()):
        checkpoint_model = {k.replace('encoder.', ''): v
                            for k, v in checkpoint_model.items()
                            if k.startswith('encoder.')}
        logger.info('Detected pre-trained model, removing [encoder.] prefix.')
    else:
        logger.info(
            'Detected non-pre-trained model, passing without doing anything.')

    if config.MODEL.TYPE in ['swin', 'swinv2']:
        logger.info(
            ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint = remap_pretrained_keys_swin(
            model, checkpoint_model, logger)
    else:
        raise NotImplementedError

    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()

    logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'")
def remap_pretrained_keys_swin(model, checkpoint_model, logger):
    state_dict = model.state_dict()

    # Interpolate relative position bias tables whenever the pre-trained
    # table size differs from the current model's.
    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_bias_table" in key:
            logger.info(f"Key: {key}")
            rel_position_bias_table_pretrained = checkpoint_model[key]
            rel_position_bias_table_current = state_dict[key]
            L1, nH1 = rel_position_bias_table_pretrained.size()
            L2, nH2 = rel_position_bias_table_current.size()
            if nH1 != nH2:
                logger.info(f"Error in loading {key}, passing......")
            else:
                if L1 != L2:
                    logger.info(
                        f"{key}: Interpolate " +
                        "relative_position_bias_table using geo.")
                    src_size = int(L1 ** 0.5)
                    dst_size = int(L2 ** 0.5)

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r ** n) / (1.0 - r)

                    # Binary-search the common ratio q whose geometric
                    # progression over src_size // 2 steps spans the target
                    # half-grid.
                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)

                    r_ids = [-_ for _ in reversed(dis)]
                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    logger.info("Original positions = %s" % str(x))
                    logger.info("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []
                    for i in range(nH1):
                        z = rel_position_bias_table_pretrained[:, i].view(
                            src_size, src_size).float().numpy()
                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias_host = \
                            torch.Tensor(f_cubic(dx, dy)
                                         ).contiguous().view(-1, 1)
                        all_rel_pos_bias.append(
                            all_rel_pos_bias_host.to(
                                rel_position_bias_table_pretrained.device))

                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                    checkpoint_model[key] = new_rel_pos_bias

    # The buffers below are re-created by the model itself, so the
    # pre-trained copies are dropped rather than loaded.
    relative_position_index_keys = [
        k for k in checkpoint_model.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del checkpoint_model[k]

    relative_coords_table_keys = [
        k for k in checkpoint_model.keys() if "relative_coords_table" in k]
    for k in relative_coords_table_keys:
        del checkpoint_model[k]

    attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del checkpoint_model[k]

    return checkpoint_model
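# The interpolation above treats each head's bias table as a src_size x
# src_size grid sampled at geometrically spaced coordinates and resamples it
# onto the target grid with bicubic interpolation. Note that
# scipy.interpolate.interp2d is deprecated in recent SciPy releases (and
# removed in SciPy 1.14); RectBivariateSpline or RegularGridInterpolator
# would be the usual replacements if this needs to run on a newer SciPy.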
def remap_pretrained_keys_vit(model, checkpoint_model, logger):
    if getattr(model, 'use_rel_pos_bias', False) and \
            "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
        logger.info(
            "Expand the shared relative position " +
            "embedding to each transformer block.")
        num_layers = model.get_num_layers()
        rel_pos_bias = \
            checkpoint_model["rel_pos_bias.relative_position_bias_table"]
        for i in range(num_layers):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" %
                             i] = rel_pos_bias.clone()
        checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
        if "relative_position_bias_table" in key:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - \
                (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                logger.info("Position interpolate for " +
                            "%s from %dx%d to %dx%d" % (
                                key,
                                src_size,
                                src_size,
                                dst_size,
                                dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]
                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                logger.info("Original positions = %s" % str(x))
                logger.info("Target positions = %s" % str(dx))

                all_rel_pos_bias = []
                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(
                        src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias_host = \
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1)
                    all_rel_pos_bias.append(
                        all_rel_pos_bias_host.to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                new_rel_pos_bias = torch.cat(
                    (rel_pos_bias, extra_tokens), dim=0)
                checkpoint_model[key] = new_rel_pos_bias

    return checkpoint_model
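# remap_pretrained_keys_vit mirrors the Swin remapping: relative position
# bias tables are resampled with the same geometric scheme, while any extra
# (non-patch) token rows are split off first and re-attached unchanged.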