|
import numpy as np |
|
import shutil |
|
import torch |
|
import os |
|
import io |
|
import logging |
|
from collections import defaultdict |
|
|
|
|
|
from torch.nn import BatchNorm2d, SyncBatchNorm |
|
|
|
def param_group_no_wd(model):
    """Split ``model``'s parameters into weight-decay and no-weight-decay groups.

    The following parameters are exempted from weight decay:

    * the ``bias`` of every ``Conv2d`` and ``Linear`` module, and
    * the ``weight`` and ``bias`` of every ``BatchNorm1d``, ``BatchNorm2d``
      and ``SyncBatchNorm`` module.

    Every other parameter yielded by ``model.parameters()`` goes into the
    normal group.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        A tuple ``(param_groups, type2num)`` where ``param_groups`` is
        ``[{'params': normal}, {'params': no_wd, 'weight_decay': 0.0}]`` and
        ``type2num`` is a ``defaultdict`` mapping ``'<ClassName>.<attr>'``
        (e.g. ``'Conv2d.bias'``) to the number of exempted parameters of
        that kind.
    """
    bias_only_types = (torch.nn.Conv2d, torch.nn.Linear)
    norm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d, SyncBatchNorm)

    pgroup_no_wd = []
    # Parameters are matched by identity rather than by dotted name: the root
    # module's name is '' so a name built as name + '.bias' ('.bias') never
    # equals the name reported by named_parameters() ('bias'), which used to
    # place such a parameter in both groups.
    no_wd_ids = set()
    type2num = defaultdict(int)

    def _exempt(module, attr):
        """Record ``module.<attr>`` as no-weight-decay if it exists and is new."""
        param = getattr(module, attr)
        if param is None or id(param) in no_wd_ids:
            return
        no_wd_ids.add(id(param))
        pgroup_no_wd.append(param)
        type2num[module.__class__.__name__ + '.' + attr] += 1

    for m in model.modules():
        if isinstance(m, bias_only_types):
            _exempt(m, 'bias')
        elif isinstance(m, norm_types):
            _exempt(m, 'weight')
            _exempt(m, 'bias')

    pgroup_normal = [p for p in model.parameters() if id(p) not in no_wd_ids]

    return [{'params': pgroup_normal}, {'params': pgroup_no_wd, 'weight_decay': 0.0}], type2num
|
|
|
def param_group_fc(model):
    """Separate the ``logits`` layer weight from all other parameters.

    The ``logits`` attribute is looked up on ``model.module`` when ``model``
    has a ``module`` attribute, and on ``model`` itself otherwise, so both a
    wrapped and an unwrapped model are accepted.

    Args:
        model: a ``torch.nn.Module`` exposing ``logits.weight`` either
            directly or through its ``module`` attribute.

    Returns:
        ``[{'params': fc_group}, {'params': normal_group}]`` where
        ``fc_group`` holds the ``logits.weight`` parameter and
        ``normal_group`` holds every other parameter of ``model``.
    """
    # Fall back to the model itself when there is no ``.module`` wrapper
    # instead of failing with AttributeError.
    core = getattr(model, 'module', model)
    logits_weight = core.logits.weight

    fc_group = []
    normal_group = []
    for p in model.parameters():
        # Identity comparison: tensors must not be compared with ``==`` here.
        if p is logits_weight:
            fc_group.append(p)
        else:
            normal_group.append(p)

    return [{'params': fc_group}, {'params': normal_group}]
|
|
|
def param_group_multitask(model):
    """Group ``model``'s parameters by backbone / neck / decoder / other.

    A parameter is assigned by the first of these substrings found in its
    name as reported by ``model.named_parameters()``:
    ``'module.backbone_module'``, ``'module.neck_module'``,
    ``'module.decoder_module'``. Parameters matching none of them go into a
    fourth "other" group.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        A list of ``{'params': [...]}`` dicts in the order backbone, neck,
        decoder, followed by the "other" group only when it is non-empty.
    """
    backbone_group = []
    neck_group = []
    decoder_group = []
    other_group = []
    for name, p in model.named_parameters():
        if 'module.backbone_module' in name:
            backbone_group.append(p)
        elif 'module.neck_module' in name:
            neck_group.append(p)
        elif 'module.decoder_module' in name:
            decoder_group.append(p)
        else:
            other_group.append(p)

    param_group = [
        {'params': backbone_group},
        {'params': neck_group},
        {'params': decoder_group},
    ]
    if other_group:
        # Must be a dict; the former ``{'params', other_group}`` was a set
        # literal and raised TypeError (unhashable list).
        param_group.append({'params': other_group})
    return param_group
|
|