import os
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
from scipy import interpolate
from torch import optim
def build_optimizer(config, model, is_pretrain=False, logger=None):
"""
Build optimizer, set weight decay of normalization to 0 by default.
AdamW only.
"""
logger.info('>>>>>>>>>> Build Optimizer')
skip = {}
skip_keywords = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
if hasattr(model, 'no_weight_decay_keywords'):
skip_keywords = model.no_weight_decay_keywords()
if is_pretrain:
parameters = get_pretrain_param_groups(model, skip, skip_keywords)
else:
depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \
else config.MODEL.SWINV2.DEPTHS
num_layers = sum(depths)
get_layer_func = partial(get_swin_layer,
num_layers=num_layers + 2,
depths=depths)
        scales = [config.TRAIN.LAYER_DECAY ** i
                  for i in reversed(range(num_layers + 2))]
parameters = get_finetune_param_groups(model,
config.TRAIN.BASE_LR,
config.TRAIN.WEIGHT_DECAY,
get_layer_func,
scales,
skip,
skip_keywords)
    optimizer = optim.AdamW(parameters,
                            eps=config.TRAIN.OPTIMIZER.EPS,
                            betas=config.TRAIN.OPTIMIZER.BETAS,
                            lr=config.TRAIN.BASE_LR,
                            weight_decay=config.TRAIN.WEIGHT_DECAY)
logger.info(optimizer)
return optimizer
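# Usage sketch (hypothetical `config`, `model`, and `logger`; the config is
# assumed to be a yacs-style CfgNode exposing the keys read above, e.g.
# TRAIN.BASE_LR and TRAIN.OPTIMIZER.*):
#
#   >>> optimizer = build_optimizer(config, model, is_pretrain=True,
#   ...                             logger=logger)
#   >>> loss = model(images).mean()      # placeholder loss
#   >>> optimizer.zero_grad()
#   >>> loss.backward()
#   >>> optimizer.step()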
def set_weight_decay(model, skip_list=(), skip_keywords=()):
"""
Args:
model (_type_): _description_
skip_list (tuple, optional): _description_. Defaults to ().
skip_keywords (tuple, optional): _description_. Defaults to ().
Returns:
_type_: _description_
"""
has_decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") \
or (name in skip_list) or \
check_keywords_in_name(name, skip_keywords):
no_decay.append(param)
else:
has_decay.append(param)
return [{'params': has_decay},
{'params': no_decay, 'weight_decay': 0.}]
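# Minimal sketch of how the two groups plug into an optimizer (hypothetical
# model; 'absolute_pos_embed' is just an example skip entry):
#
#   >>> groups = set_weight_decay(model, skip_list=('absolute_pos_embed',))
#   >>> opt = optim.AdamW(groups, lr=1e-4, weight_decay=0.05)
#
# The first group inherits the optimizer's weight_decay (0.05 here); the
# second group overrides it with 0.0.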
def check_keywords_in_name(name, keywords=()):
    """Return True if any of the given keywords occurs in `name`."""
    return any(keyword in name for keyword in keywords)
def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()):
has_decay = []
no_decay = []
has_decay_name = []
no_decay_name = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith(".bias") or \
(name in skip_list) or \
check_keywords_in_name(name, skip_keywords):
no_decay.append(param)
no_decay_name.append(name)
else:
has_decay.append(param)
has_decay_name.append(name)
return [{'params': has_decay},
{'params': no_decay, 'weight_decay': 0.}]
def get_swin_layer(name, num_layers, depths):
    if name in ("mask_token",):
return 0
elif name.startswith("patch_embed"):
return 0
elif name.startswith("layers"):
layer_id = int(name.split('.')[1])
block_id = name.split('.')[3]
if block_id == 'reduction' or block_id == 'norm':
return sum(depths[:layer_id + 1])
layer_id = sum(depths[:layer_id]) + int(block_id)
return layer_id + 1
else:
return num_layers - 1
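# Worked example of the layer-id assignment, assuming standard Swin parameter
# naming and SwinV2-Tiny depths = [2, 2, 6, 2], so num_layers passed in is
# sum(depths) + 2 = 14:
#
#   patch_embed.*, mask_token                 -> 0
#   layers.0.blocks.0.*                       -> 1
#   layers.0.blocks.1.*                       -> 2
#   layers.0.downsample.reduction / .norm     -> 2   (sum(depths[:1]))
#   layers.3.blocks.1.*                       -> 12
#   anything else (e.g. norm.*, head.*)       -> 13  (num_layers - 1)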
def get_finetune_param_groups(model,
lr,
weight_decay,
get_layer_func,
scales,
skip_list=(),
skip_keywords=()):
parameter_group_names = {}
parameter_group_vars = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith(".bias") \
or (name in skip_list) or \
check_keywords_in_name(name, skip_keywords):
group_name = "no_decay"
this_weight_decay = 0.
else:
group_name = "decay"
this_weight_decay = weight_decay
if get_layer_func is not None:
layer_id = get_layer_func(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
else:
layer_id = None
if group_name not in parameter_group_names:
if scales is not None:
scale = scales[layer_id]
else:
scale = 1.
parameter_group_names[group_name] = {
"group_name": group_name,
"weight_decay": this_weight_decay,
"params": [],
"lr": lr * scale,
"lr_scale": scale,
}
parameter_group_vars[group_name] = {
"group_name": group_name,
"weight_decay": this_weight_decay,
"params": [],
"lr": lr * scale,
"lr_scale": scale
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
return list(parameter_group_vars.values())
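# Layer-wise LR decay sketch: with config.TRAIN.LAYER_DECAY = 0.9 (an example
# value) and 14 layer ids as in the SwinV2-Tiny example above, `scales` in
# build_optimizer becomes
#
#   >>> scales = [0.9 ** i for i in reversed(range(14))]
#   >>> round(scales[0], 4), scales[-1]
#   (0.2542, 1.0)
#
# so patch_embed (layer id 0) trains at lr * 0.2542 while the last layer id
# keeps the full base lr.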
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........")
if config.MODEL.RESUME.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
config.MODEL.RESUME, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
# re-map keys due to name change (only for loading provided models)
rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k]
for k in rpe_mlp_keys:
checkpoint['model'][k.replace(
'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k)
msg = model.load_state_dict(checkpoint['model'], strict=False)
logger.info(msg)
max_accuracy = 0.0
if not config.EVAL_MODE and 'optimizer' in checkpoint \
and 'lr_scheduler' in checkpoint \
and 'scaler' in checkpoint \
and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
scaler.load_state_dict(checkpoint['scaler'])
config.defrost()
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
config.freeze()
logger.info(
f"=> loaded successfully '{config.MODEL.RESUME}' " +
f"(epoch {checkpoint['epoch']})")
if 'max_accuracy' in checkpoint:
max_accuracy = checkpoint['max_accuracy']
else:
max_accuracy = 0.0
del checkpoint
torch.cuda.empty_cache()
return max_accuracy
def save_checkpoint(config, epoch, model, max_accuracy,
optimizer, lr_scheduler, scaler, logger):
save_state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'scaler': scaler.state_dict(),
'max_accuracy': max_accuracy,
'epoch': epoch,
'config': config}
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
logger.info(f"{save_path} saving......")
torch.save(save_state, save_path)
logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
return total_norm
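# Sanity-check sketch: for norm_type=2 this reproduces the total norm that
# torch.nn.utils.clip_grad_norm_ returns (tiny throwaway model, huge max_norm
# so nothing is actually clipped):
#
#   >>> m = torch.nn.Linear(4, 2)
#   >>> m(torch.randn(3, 4)).sum().backward()
#   >>> ref = torch.nn.utils.clip_grad_norm_(m.parameters(), 1e9).item()
#   >>> abs(get_grad_norm(m.parameters()) - ref) < 1e-5
#   True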
def auto_resume_helper(output_dir, logger):
checkpoints = os.listdir(output_dir)
checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}")
if len(checkpoints) > 0:
latest_checkpoint = max([os.path.join(output_dir, d)
for d in checkpoints], key=os.path.getmtime)
logger.info(f"The latest checkpoint founded: {latest_checkpoint}")
resume_file = latest_checkpoint
else:
resume_file = None
return resume_file
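# Typical resume flow (hypothetical objects, mirroring how these helpers are
# usually wired together in a training script):
#
#   >>> resume_file = auto_resume_helper(config.OUTPUT, logger)
#   >>> if resume_file is not None:
#   ...     config.defrost()
#   ...     config.MODEL.RESUME = resume_file
#   ...     config.freeze()
#   ...     max_accuracy = load_checkpoint(config, model, optimizer,
#   ...                                    lr_scheduler, scaler, logger)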
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
return rt
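# Usage sketch (requires an initialized torch.distributed process group, e.g.
# via torch.distributed.init_process_group):
#
#   >>> loss_avg = reduce_tensor(loss.detach())  # mean of `loss` across ranks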
def load_pretrained(config, model, logger):
logger.info(
f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........")
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
checkpoint_model = checkpoint['model']
    if any('encoder.' in k for k in checkpoint_model):
        checkpoint_model = {k.replace('encoder.', ''): v
                            for k, v in checkpoint_model.items()
                            if k.startswith('encoder.')}
        logger.info('Detected a pre-training checkpoint, '
                    'removing the [encoder.] prefix.')
    else:
        logger.info('No [encoder.] prefix detected, '
                    'loading the checkpoint as-is.')
if config.MODEL.TYPE in ['swin', 'swinv2']:
logger.info(
">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint_model = remap_pretrained_keys_swin(
            model, checkpoint_model, logger)
else:
raise NotImplementedError
msg = model.load_state_dict(checkpoint_model, strict=False)
logger.info(msg)
del checkpoint
torch.cuda.empty_cache()
logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'")
def remap_pretrained_keys_swin(model, checkpoint_model, logger):
state_dict = model.state_dict()
    # Geometric interpolation of the relative position bias tables when the
    # pre-trained window size does not match the fine-tuned window size.
all_keys = list(checkpoint_model.keys())
for key in all_keys:
if "relative_position_bias_table" in key:
logger.info(f"Key: {key}")
rel_position_bias_table_pretrained = checkpoint_model[key]
rel_position_bias_table_current = state_dict[key]
L1, nH1 = rel_position_bias_table_pretrained.size()
L2, nH2 = rel_position_bias_table_current.size()
if nH1 != nH2:
logger.info(f"Error in loading {key}, passing......")
else:
if L1 != L2:
logger.info(
f"{key}: Interpolate " +
"relative_position_bias_table using geo.")
src_size = int(L1 ** 0.5)
dst_size = int(L2 ** 0.5)
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
logger.info("Original positions = %s" % str(x))
logger.info("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(nH1):
z = rel_position_bias_table_pretrained[:, i].view(
src_size, src_size).float().numpy()
                        # NOTE: scipy.interpolate.interp2d was deprecated in
                        # SciPy 1.10 and removed in 1.14; this call assumes an
                        # older SciPy is installed.
                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias_host = \
torch.Tensor(f_cubic(dx, dy)
).contiguous().view(-1, 1)
all_rel_pos_bias.append(
all_rel_pos_bias_host.to(
rel_position_bias_table_pretrained.device))
new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
checkpoint_model[key] = new_rel_pos_bias
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in checkpoint_model.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
del checkpoint_model[k]
# delete relative_coords_table since we always re-init it
relative_coords_table_keys = [
k for k in checkpoint_model.keys() if "relative_coords_table" in k]
for k in relative_coords_table_keys:
del checkpoint_model[k]
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
for k in attn_mask_keys:
del checkpoint_model[k]
return checkpoint_model
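# Interpolation sketch: the bisection above picks a ratio q so that the source
# coordinates are geometrically spaced, denser near the centre of the bias
# table. For example, growing a 7x7-window table (13x13 relative offsets) to a
# 12x12-window table (23x23 offsets) solves 1 + q + ... + q**5 ~= 11, giving
# q ~= 1.24; the cubic fit is then evaluated on the integer target offsets
# -11 ... 11.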
def remap_pretrained_keys_vit(model, checkpoint_model, logger):
# Duplicate shared rel_pos_bias to each layer
if getattr(model, 'use_rel_pos_bias', False) and \
"rel_pos_bias.relative_position_bias_table" in checkpoint_model:
logger.info(
"Expand the shared relative position " +
"embedding to each transformer block.")
num_layers = model.get_num_layers()
rel_pos_bias = \
checkpoint_model["rel_pos_bias.relative_position_bias_table"]
for i in range(num_layers):
checkpoint_model["blocks.%d.attn.relative_position_bias_table" %
i] = rel_pos_bias.clone()
checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")
    # Geometric interpolation of the relative position bias tables when the
    # pre-trained patch grid does not match the fine-tuned patch grid.
all_keys = list(checkpoint_model.keys())
for key in all_keys:
if "relative_position_index" in key:
checkpoint_model.pop(key)
if "relative_position_bias_table" in key:
rel_pos_bias = checkpoint_model[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = model.state_dict()[key].size()
dst_patch_shape = model.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - \
(dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
if src_size != dst_size:
logger.info("Position interpolate for " +
"%s from %dx%d to %dx%d" % (
key,
src_size,
src_size,
dst_size,
dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
logger.info("Original positions = %s" % str(x))
logger.info("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(
src_size, src_size).float().numpy()
                    # NOTE: see the interp2d deprecation note in
                    # remap_pretrained_keys_swin above.
                    f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias_host = \
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1)
all_rel_pos_bias.append(
all_rel_pos_bias_host.to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat(
(rel_pos_bias, extra_tokens), dim=0)
checkpoint_model[key] = new_rel_pos_bias
return checkpoint_model
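# Note on num_extra_tokens above: for a BEiT-style ViT with a 14x14 patch
# grid, the positional part of the bias table has (2*14 - 1)**2 = 729 rows, so
# a 732-row table implies num_extra_tokens = 3 (the relative biases involving
# the cls token), which are carried over unchanged while only the 729
# positional rows are interpolated.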