kittendev commited on
Commit
42c5722
·
verified ·
1 Parent(s): 7ecd09f

Update core/networks.py

Browse files
Files changed (1) hide show
  1. core/networks.py +37 -37
core/networks.py CHANGED
@@ -35,44 +35,44 @@ def group_norm(features):
35
 
36
  class Backbone(nn.Module, ABC_Model):
37
  def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
38
- super().__init__()
39
-
40
- self.mode = mode
41
-
42
- if self.mode == 'fix':
43
- self.norm_fn = FixedBatchNorm
44
- else:
45
- self.norm_fn = nn.BatchNorm2d
46
-
47
- if 'resnet' in model_name:
48
- self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
49
- batch_norm_fn=self.norm_fn)
50
-
51
- state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
52
- state_dict.pop('fc.weight')
53
- state_dict.pop('fc.bias')
54
-
55
- self.model.load_state_dict(state_dict)
56
- else:
57
- if segmentation:
58
- dilation, dilated = 4, True
59
  else:
60
- dilation, dilated = 2, False
61
-
62
- self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
63
- norm_layer=self.norm_fn)
64
-
65
- del self.model.avgpool
66
- del self.model.fc
67
-
68
- self.stage1 = nn.Sequential(self.model.conv1,
69
- self.model.bn1,
70
- self.model.relu,
71
- self.model.maxpool)
72
- self.stage2 = nn.Sequential(self.model.layer1)
73
- self.stage3 = nn.Sequential(self.model.layer2)
74
- self.stage4 = nn.Sequential(self.model.layer3)
75
- self.stage5 = nn.Sequential(self.model.layer4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  class Classifier(Backbone):
78
  def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
 
35
 
36
  class Backbone(nn.Module, ABC_Model):
37
  def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
38
+ super().__init__()
39
+
40
+ self.mode = mode
41
+
42
+ if self.mode == 'fix':
43
+ self.norm_fn = FixedBatchNorm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  else:
45
+ self.norm_fn = nn.BatchNorm2d
46
+
47
+ if 'resnet' in model_name:
48
+ self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
49
+ batch_norm_fn=self.norm_fn)
50
+
51
+ state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
52
+ state_dict.pop('fc.weight')
53
+ state_dict.pop('fc.bias')
54
+
55
+ self.model.load_state_dict(state_dict)
56
+ else:
57
+ if segmentation:
58
+ dilation, dilated = 4, True
59
+ else:
60
+ dilation, dilated = 2, False
61
+
62
+ self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
63
+ norm_layer=self.norm_fn)
64
+
65
+ del self.model.avgpool
66
+ del self.model.fc
67
+
68
+ self.stage1 = nn.Sequential(self.model.conv1,
69
+ self.model.bn1,
70
+ self.model.relu,
71
+ self.model.maxpool)
72
+ self.stage2 = nn.Sequential(self.model.layer1)
73
+ self.stage3 = nn.Sequential(self.model.layer2)
74
+ self.stage4 = nn.Sequential(self.model.layer3)
75
+ self.stage5 = nn.Sequential(self.model.layer4)
76
 
77
  class Classifier(Backbone):
78
  def __init__(self, model_name, state_path, num_classes=20, mode='fix'):