import math

import torch
import torch.nn as nn

from abc import ABC


class ABC_Model(ABC):
    """Mixin with pooling, initialization and parameter-grouping helpers.

    ``get_parameter_groups`` calls ``self.named_parameters()``, which this
    class does not define; the concrete subclass must supply it
    (presumably by also inheriting from ``torch.nn.Module`` -- confirm
    against the subclasses).
    """

    def global_average_pooling_2d(self, x, keepdims=False):
        """Average ``x`` over every dimension after the first two.

        Args:
            x: Tensor with at least two dimensions. The first two are kept
                and all remaining ones are flattened and averaged
                (presumably ``(batch, channels, height, width)`` -- confirm
                against callers).
            keepdims: When true, return shape ``(x.size(0), x.size(1), 1, 1)``
                instead of ``(x.size(0), x.size(1))``.

        Returns:
            The pooled tensor.
        """
        batch, channels = x.size(0), x.size(1)
        # reshape() rather than view(): view() raises on non-contiguous
        # inputs (e.g. a permuted tensor), reshape() copies only when needed.
        pooled = x.reshape(batch, channels, -1).mean(dim=-1)
        if keepdims:
            pooled = pooled.view(batch, channels, 1, 1)
        return pooled

    def initialize(self, modules):
        """Re-initialize the Conv2d and BatchNorm2d layers found in ``modules``.

        Conv2d weights are drawn with ``kaiming_normal_`` (default arguments);
        conv biases are left untouched. BatchNorm2d weights are set to 1 and
        biases to 0. Any other module type is ignored.

        Args:
            modules: Iterable of modules to inspect.
        """
        for module in modules:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                # With affine=False the layer has no learnable scale/shift
                # (both attributes are None), so there is nothing to reset.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def get_parameter_groups(self, print_fn=print):
        """Split the parameters into four lists by name.

        A parameter whose name contains ``'model'`` counts as pretrained,
        anything else as trained from scratch; within each set, a name
        containing ``'weight'`` goes to the weights list and every other
        name to the bias list.

        Args:
            print_fn: Callable invoked with one message per from-scratch
                parameter, or ``None`` to stay silent. Pretrained parameters
                are never reported.

        Returns:
            Tuple ``(pretrained_weights, pretrained_biases,
            scratch_weights, scratch_biases)`` of lists of parameters.
        """
        pretrained_weights, pretrained_biases = [], []
        scratch_weights, scratch_biases = [], []

        for name, param in self.named_parameters():
            is_weight = 'weight' in name

            # Pretrained parameters are grouped silently.
            if 'model' in name:
                target = pretrained_weights if is_weight else pretrained_biases
                target.append(param)
                continue

            # Parameters trained from scratch are reported when a printer is given.
            if is_weight:
                if print_fn is not None:
                    print_fn(f'scratched weights : {name}')
                scratch_weights.append(param)
            else:
                if print_fn is not None:
                    print_fn(f'scratched bias : {name}')
                scratch_biases.append(param)

        return (pretrained_weights, pretrained_biases,
                scratch_weights, scratch_biases)