kittendev commited on
Commit
7ecd09f
1 Parent(s): 6dbf305

Update core/networks.py

Browse files
Files changed (1) hide show
  1. core/networks.py +33 -33
core/networks.py CHANGED
@@ -34,45 +34,45 @@ def group_norm(features):
34
  #######################################################################
35
 
36
  class Backbone(nn.Module, ABC_Model):
37
- def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
38
- super().__init__()
39
 
40
- self.mode = mode
41
 
42
- if self.mode == 'fix':
43
- self.norm_fn = FixedBatchNorm
44
- else:
45
- self.norm_fn = nn.BatchNorm2d
46
 
47
- if 'resnet' in model_name:
48
- self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
49
- batch_norm_fn=self.norm_fn)
50
 
51
- state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
52
- state_dict.pop('fc.weight')
53
- state_dict.pop('fc.bias')
54
 
55
- self.model.load_state_dict(state_dict)
 
 
 
56
  else:
57
- if segmentation:
58
- dilation, dilated = 4, True
59
- else:
60
- dilation, dilated = 2, False
61
-
62
- self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
63
- norm_layer=self.norm_fn)
64
-
65
- del self.model.avgpool
66
- del self.model.fc
67
-
68
- self.stage1 = nn.Sequential(self.model.conv1,
69
- self.model.bn1,
70
- self.model.relu,
71
- self.model.maxpool)
72
- self.stage2 = nn.Sequential(self.model.layer1)
73
- self.stage3 = nn.Sequential(self.model.layer2)
74
- self.stage4 = nn.Sequential(self.model.layer3)
75
- self.stage5 = nn.Sequential(self.model.layer4)
76
 
77
  class Classifier(Backbone):
78
  def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
 
34
  #######################################################################
35
 
36
  class Backbone(nn.Module, ABC_Model):
37
+ def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
38
+ super().__init__()
39
 
40
+ self.mode = mode
41
 
42
+ if self.mode == 'fix':
43
+ self.norm_fn = FixedBatchNorm
44
+ else:
45
+ self.norm_fn = nn.BatchNorm2d
46
 
47
+ if 'resnet' in model_name:
48
+ self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
49
+ batch_norm_fn=self.norm_fn)
50
 
51
+ state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
52
+ state_dict.pop('fc.weight')
53
+ state_dict.pop('fc.bias')
54
 
55
+ self.model.load_state_dict(state_dict)
56
+ else:
57
+ if segmentation:
58
+ dilation, dilated = 4, True
59
  else:
60
+ dilation, dilated = 2, False
61
+
62
+ self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
63
+ norm_layer=self.norm_fn)
64
+
65
+ del self.model.avgpool
66
+ del self.model.fc
67
+
68
+ self.stage1 = nn.Sequential(self.model.conv1,
69
+ self.model.bn1,
70
+ self.model.relu,
71
+ self.model.maxpool)
72
+ self.stage2 = nn.Sequential(self.model.layer1)
73
+ self.stage3 = nn.Sequential(self.model.layer2)
74
+ self.stage4 = nn.Sequential(self.model.layer3)
75
+ self.stage5 = nn.Sequential(self.model.layer4)
 
 
 
76
 
77
  class Classifier(Backbone):
78
  def __init__(self, model_name, state_path, num_classes=20, mode='fix'):