# WSSS_ResNet50 / core / abc_modules.py
# kittendev's picture
# Upload 176 files
# c20a1af verified
import math
import torch
import torch.nn as nn
from abc import ABC
class ABC_Model(ABC):
    """Mixin with pooling, weight-init and parameter-grouping helpers.

    ``get_parameter_groups`` calls ``self.named_parameters()``, which this
    class does not define; it has to be supplied by the class this one is
    mixed into (presumably ``torch.nn.Module`` -- confirm against the
    subclasses).
    """

    def global_average_pooling_2d(self, x, keepdims=False):
        """Average ``x`` over every dimension after the first two.

        Args:
            x: Tensor with at least two dimensions. The first two are kept
                and all remaining ones are flattened and averaged.
            keepdims: If true, the result has shape
                ``(x.size(0), x.size(1), 1, 1)`` instead of
                ``(x.size(0), x.size(1))``.

        Returns:
            Tensor holding the mean over the trailing dimensions.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed feature maps; it copies only when a view is impossible.
        pooled = x.reshape(x.size(0), x.size(1), -1).mean(-1)
        if keepdims:
            pooled = pooled.view(x.size(0), x.size(1), 1, 1)
        return pooled

    def initialize(self, modules):
        """Re-initialise convolution and batch-norm layers in place.

        ``nn.Conv2d`` weights are drawn with ``kaiming_normal_`` using its
        default arguments; convolution biases are left untouched.
        ``nn.BatchNorm2d`` layers get weight 1 and bias 0. Any other module
        type is skipped.

        Args:
            modules: Iterable of modules to visit.
        """
        for m in modules:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # With affine=False the layer has weight/bias set to None,
                # so there is nothing to reset.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_parameter_groups(self, print_fn=print):
        """Split ``self.named_parameters()`` into four lists by name.

        A parameter whose name contains ``'model'`` is treated as
        pretrained, every other one as trained from scratch. Within each of
        those, a name containing ``'weight'`` goes to the weight list and
        anything else to the bias list.

        Args:
            print_fn: Callable invoked with one message per from-scratch
                parameter, or ``None`` to stay silent. Pretrained
                parameters are never reported.

        Returns:
            Tuple of four lists: ``(pretrained weights, pretrained biases,
            scratch weights, scratch biases)``.
        """
        groups = ([], [], [], [])
        for name, value in self.named_parameters():
            is_scratch = 'model' not in name
            is_weight = 'weight' in name

            if is_scratch and print_fn is not None:
                label = 'scratched weights' if is_weight else 'scratched bias'
                print_fn(f'{label} : {name}')

            # Index layout: 0/1 = pretrained weight/bias, 2/3 = scratch.
            index = (2 if is_scratch else 0) + (0 if is_weight else 1)
            groups[index].append(value)
        return groups