Spaces:
Sleeping
Sleeping
Update core/networks.py
Browse files- core/networks.py +37 -37
core/networks.py
CHANGED
@@ -35,44 +35,44 @@ def group_norm(features):
|
|
35 |
|
36 |
class Backbone(nn.Module, ABC_Model):
|
37 |
def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
else:
|
45 |
-
self.norm_fn = nn.BatchNorm2d
|
46 |
-
|
47 |
-
if 'resnet' in model_name:
|
48 |
-
self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
|
49 |
-
batch_norm_fn=self.norm_fn)
|
50 |
-
|
51 |
-
state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
|
52 |
-
state_dict.pop('fc.weight')
|
53 |
-
state_dict.pop('fc.bias')
|
54 |
-
|
55 |
-
self.model.load_state_dict(state_dict)
|
56 |
-
else:
|
57 |
-
if segmentation:
|
58 |
-
dilation, dilated = 4, True
|
59 |
else:
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
class Classifier(Backbone):
|
78 |
def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
|
|
|
35 |
|
36 |
class Backbone(nn.Module, ABC_Model):
|
37 |
def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.mode = mode
|
41 |
+
|
42 |
+
if self.mode == 'fix':
|
43 |
+
self.norm_fn = FixedBatchNorm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
else:
|
45 |
+
self.norm_fn = nn.BatchNorm2d
|
46 |
+
|
47 |
+
if 'resnet' in model_name:
|
48 |
+
self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
|
49 |
+
batch_norm_fn=self.norm_fn)
|
50 |
+
|
51 |
+
state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
|
52 |
+
state_dict.pop('fc.weight')
|
53 |
+
state_dict.pop('fc.bias')
|
54 |
+
|
55 |
+
self.model.load_state_dict(state_dict)
|
56 |
+
else:
|
57 |
+
if segmentation:
|
58 |
+
dilation, dilated = 4, True
|
59 |
+
else:
|
60 |
+
dilation, dilated = 2, False
|
61 |
+
|
62 |
+
self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
|
63 |
+
norm_layer=self.norm_fn)
|
64 |
+
|
65 |
+
del self.model.avgpool
|
66 |
+
del self.model.fc
|
67 |
+
|
68 |
+
self.stage1 = nn.Sequential(self.model.conv1,
|
69 |
+
self.model.bn1,
|
70 |
+
self.model.relu,
|
71 |
+
self.model.maxpool)
|
72 |
+
self.stage2 = nn.Sequential(self.model.layer1)
|
73 |
+
self.stage3 = nn.Sequential(self.model.layer2)
|
74 |
+
self.stage4 = nn.Sequential(self.model.layer3)
|
75 |
+
self.stage5 = nn.Sequential(self.model.layer4)
|
76 |
|
77 |
class Classifier(Backbone):
|
78 |
def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
|