ForkedHulk2 / core /make_param_group.py
tuandunghcmut's picture
Upload folder using huggingface_hub
345ee20 verified
import numpy as np
import shutil
import torch
import os
import io
import logging
from collections import defaultdict
from torch.nn import BatchNorm2d, SyncBatchNorm
def param_group_no_wd(model):
    """Split a model's parameters into weight-decayed and non-decayed groups.

    The bias of every ``Conv2d`` / ``Linear`` module and the affine weight and
    bias of every ``BatchNorm1d`` / ``BatchNorm2d`` / ``SyncBatchNorm`` module
    are collected into a group whose ``weight_decay`` is ``0.0``. All other
    parameters go to the normal group.

    Parameters are matched by object identity rather than by dotted name, so
    the two groups never overlap -- even when ``model`` itself is one of the
    layer types above (its module name is the empty string) or when a
    parameter is shared between modules.

    Args:
        model: the ``torch.nn.Module`` whose parameters are to be grouped.

    Returns:
        A tuple ``(param_groups, type2num)``:
        ``param_groups`` is
        ``[{'params': normal}, {'params': no_wd, 'weight_decay': 0.0}]``;
        ``type2num`` is a ``defaultdict`` mapping ``'<ClassName>.bias'`` /
        ``'<ClassName>.weight'`` to the number of parameters of that kind
        placed in the no-decay group.
    """
    bias_only_types = (torch.nn.Conv2d, torch.nn.Linear)
    norm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d, SyncBatchNorm)

    pgroup_no_wd = []
    no_wd_ids = set()
    type2num = defaultdict(int)

    def _exempt(module, attr):
        # Register module.<attr> as a no-weight-decay parameter (at most once).
        param = getattr(module, attr)
        if param is None or id(param) in no_wd_ids:
            return
        no_wd_ids.add(id(param))
        pgroup_no_wd.append(param)
        type2num[module.__class__.__name__ + '.' + attr] += 1

    for _, module in model.named_modules():
        if isinstance(module, bias_only_types):
            _exempt(module, 'bias')
        elif isinstance(module, norm_types):
            _exempt(module, 'weight')
            _exempt(module, 'bias')

    # Identity-based filter: O(1) per lookup and immune to name mismatches.
    pgroup_normal = [p for p in model.parameters() if id(p) not in no_wd_ids]

    param_groups = [
        {'params': pgroup_normal},
        {'params': pgroup_no_wd, 'weight_decay': 0.0},
    ]
    return param_groups, type2num
def param_group_fc(model):
    """Separate the ``logits`` weight from every other parameter.

    Args:
        model: an object exposing ``model.module.logits.weight`` and a
            ``parameters()`` iterator (presumably a wrapped network such as
            one under a data-parallel wrapper -- confirm against callers).

    Returns:
        ``[{'params': fc_group}, {'params': normal_group}]`` where
        ``fc_group`` holds the parameter that is the very same object as
        ``model.module.logits.weight`` and ``normal_group`` holds the rest,
        both in ``model.parameters()`` order.
    """
    classifier_weight = model.module.logits.weight
    fc_group, normal_group = [], []
    for param in model.parameters():
        # Identity check (not ==): parameters are matched as objects.
        bucket = fc_group if param is classifier_weight else normal_group
        bucket.append(param)
    return [{'params': fc_group}, {'params': normal_group}]
def param_group_multitask(model):
    """Group parameters by backbone / neck / decoder (and anything else).

    A parameter is routed by the first of these substrings found in its name
    from ``model.named_parameters()``: ``'module.backbone_module'``,
    ``'module.neck_module'``, ``'module.decoder_module'``. Parameters matching
    none of them are gathered in a fourth group.

    Args:
        model: an object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        A list of ``{'params': [...]}`` dicts in the order backbone, neck,
        decoder, followed by a fourth dict for the remaining parameters only
        when at least one such parameter exists.
    """
    backbone_group, neck_group, decoder_group, other_group = [], [], [], []
    # Checked in order; the first matching marker wins.
    routes = (
        ('module.backbone_module', backbone_group),
        ('module.neck_module', neck_group),
        ('module.decoder_module', decoder_group),
    )

    for name, param in model.named_parameters():
        for marker, bucket in routes:
            if marker in name:
                bucket.append(param)
                break
        else:
            other_group.append(param)

    param_group = [
        {'params': backbone_group},
        {'params': neck_group},
        {'params': decoder_group},
    ]
    if other_group:
        # Must be a dict; the previous set literal {'params', list} raised
        # TypeError because lists are unhashable.
        param_group.append({'params': other_group})
    return param_group