Spaces:
Sleeping
Sleeping
import math | |
import torch | |
import torch.nn as nn | |
from abc import ABC | |
class ABC_Model(ABC):
    """Helper mixin for models.

    ``get_parameter_groups`` calls ``self.named_parameters()``, so a concrete
    subclass is expected to also inherit from ``torch.nn.Module`` (or otherwise
    provide that method).
    """

    def global_average_pooling_2d(self, x, keepdims=False):
        """Average ``x`` over every dimension after the first two.

        Args:
            x: tensor of rank >= 2 whose first two dims are kept. Presumably
                (N, C, H, W) feature maps given the name — TODO confirm callers.
            keepdims: if True, return shape ``(N, C, 1, 1)`` instead of ``(N, C)``.

        Returns:
            Tensor of shape ``(N, C)``, or ``(N, C, 1, 1)`` when ``keepdims``.
        """
        # reshape (not view): a non-contiguous input such as a transposed or
        # permuted feature map cannot always be viewed and would raise.
        x = torch.mean(x.reshape(x.size(0), x.size(1), -1), -1)
        if keepdims:
            x = x.view(x.size(0), x.size(1), 1, 1)
        return x

    def initialize(self, modules):
        """Re-initialize the Conv2d and BatchNorm2d layers found in ``modules``.

        * ``nn.Conv2d``: weight is filled by ``kaiming_normal_`` with PyTorch's
          default arguments; the conv bias (if any) is left untouched.
        * ``nn.BatchNorm2d``: weight is set to 1 and bias to 0. Layers built
          with ``affine=False`` have no weight/bias and are left as they are.

        All other module types are ignored.

        Args:
            modules: iterable of modules, e.g. ``self.modules()``.
        """
        for m in modules:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # With affine=False these attributes are None; guard so such
                # layers are skipped instead of raising AttributeError.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_parameter_groups(self, print_fn=print):
        """Split ``self.named_parameters()`` into four groups.

        A parameter whose name contains ``'model'`` is treated as pretrained;
        every other parameter as trained from scratch. Within each of those,
        a name containing ``'weight'`` counts as a weight and anything else
        (plain substring matching) counts as a bias.

        Args:
            print_fn: callable used to log the name of each from-scratch
                parameter; pass ``None`` to disable logging. Pretrained
                parameters are never logged.

        Returns:
            Tuple of four lists, in order: pretrained weights, pretrained
            biases, from-scratch weights, from-scratch biases.
        """
        groups = ([], [], [], [])
        for name, value in self.named_parameters():
            is_weight = 'weight' in name
            if 'model' in name:
                groups[0 if is_weight else 1].append(value)
                continue
            if print_fn is not None:
                if is_weight:
                    print_fn(f'scratched weights : {name}')
                else:
                    print_fn(f'scratched bias : {name}')
            groups[2 if is_weight else 3].append(value)
        return groups