Spaces:
Sleeping
Sleeping
File size: 1,727 Bytes
c20a1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import math
import torch
import torch.nn as nn
from abc import ABC
class ABC_Model(ABC):
    """Helper methods shared by model classes.

    NOTE(review): ``get_parameter_groups`` calls ``self.named_parameters()``,
    which this class does not define; presumably concrete subclasses also
    inherit from ``torch.nn.Module`` -- confirm against the subclasses.
    """

    def global_average_pooling_2d(self, x, keepdims=False):
        """Average ``x`` over every dimension after the first two.

        Args:
            x: Tensor with at least two dimensions; presumably laid out as
                (batch, channels, H, W) -- TODO confirm with callers.
            keepdims: If True, return shape (x.size(0), x.size(1), 1, 1)
                instead of (x.size(0), x.size(1)).

        Returns:
            Tensor holding the mean over all trailing dimensions for each
            (dim 0, dim 1) position.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted feature maps, are flattened instead of raising.
        x = torch.mean(x.reshape(x.size(0), x.size(1), -1), -1)
        if keepdims:
            x = x.reshape(x.size(0), x.size(1), 1, 1)
        return x

    def initialize(self, modules):
        """Re-initialise parameters of selected layer types in ``modules``.

        - ``nn.Conv2d``: weight drawn with ``nn.init.kaiming_normal_`` using
          its default arguments; the conv bias is left untouched.
        - ``nn.BatchNorm2d``: weight set to 1 and bias set to 0, each only if
          present (layers built with ``affine=False`` have neither).
        - Any other module is ignored.

        Args:
            modules: Iterable of ``nn.Module`` objects, e.g. ``self.modules()``
                or an explicit list of layers.
        """
        for m in modules:
            if isinstance(m, nn.Conv2d):
                # Alternative formerly kept here as commented-out code:
                # N(0, sqrt(2 / n)) with n = k_h * k_w * out_channels, which
                # equals kaiming_normal_(mode='fan_out', nonlinearity='relu').
                # The call below keeps the library defaults instead.
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # Guard each tensor: affine=False sets weight/bias to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_parameter_groups(self, print_fn=print):
        """Split ``self.named_parameters()`` into four lists by name.

        A parameter counts as "pretrained" when its name contains the
        substring ``'model'``, and as a weight when its name contains
        ``'weight'``; any other parameter is treated as a bias.

        NOTE(review): both tests are plain substring matches, so a scratch
        sub-module whose attribute name merely contains "model" would land
        in the pretrained groups -- confirm subclasses name things to suit.

        Args:
            print_fn: Callable taking one string, used to log the name of
                each scratch-trained parameter; pass ``None`` to disable.
                Pretrained parameters are never logged.

        Returns:
            Tuple of four lists, in order: pretrained weights, pretrained
            biases, scratch weights, scratch biases.
        """
        groups = ([], [], [], [])
        for name, value in self.named_parameters():
            is_weight = 'weight' in name
            if 'model' in name:
                # pretrained parameters
                groups[0 if is_weight else 1].append(value)
            elif is_weight:
                # weights trained from scratch
                if print_fn is not None:
                    print_fn(f'scratched weights : {name}')
                groups[2].append(value)
            else:
                # biases (and other non-weight params) trained from scratch
                if print_fn is not None:
                    print_fn(f'scratched bias : {name}')
                groups[3].append(value)
        return groups
|