tuandunghcmut commited on
Commit
97f555f
·
verified ·
1 Parent(s): 9408aea

Upload model

Browse files
Files changed (1) hide show
  1. modeling_solider.py +9 -5
modeling_solider.py CHANGED
@@ -1234,7 +1234,9 @@ class SwinTransformer(BaseModule):
1234
  convert_weights=False,
1235
  frozen_stages=-1,
1236
  init_cfg=None,
1237
- semantic_weight=0.0,
 
 
1238
  ):
1239
  self.convert_weights = convert_weights
1240
  self.frozen_stages = frozen_stages
@@ -1341,6 +1343,7 @@ class SwinTransformer(BaseModule):
1341
 
1342
  # semantic embedding
1343
  self.semantic_weight = semantic_weight
 
1344
  if self.semantic_weight >= 0:
1345
  self.semantic_embed_w = ModuleList()
1346
  self.semantic_embed_b = ModuleList()
@@ -1350,10 +1353,11 @@ class SwinTransformer(BaseModule):
1350
  semantic_embed_w = nn.Linear(2, self.num_features[i + 1])
1351
  semantic_embed_b = nn.Linear(2, self.num_features[i + 1])
1352
  # TODO: Test with semantic embed unfreeze
1353
- for param in semantic_embed_w.parameters():
1354
- param.requires_grad = False
1355
- for param in semantic_embed_b.parameters():
1356
- param.requires_grad = False
 
1357
  trunc_normal_init(semantic_embed_w, std=0.02, bias=0.0)
1358
  trunc_normal_init(semantic_embed_b, std=0.02, bias=0.0)
1359
  self.semantic_embed_w.append(semantic_embed_w)
 
1234
  convert_weights=False,
1235
  frozen_stages=-1,
1236
  init_cfg=None,
1237
+ # NOTE: This is my modification based on SOLIDER
1238
+ semantic_weight=0.5,
1239
+ freeze_semantic_embedding=False,
1240
  ):
1241
  self.convert_weights = convert_weights
1242
  self.frozen_stages = frozen_stages
 
1343
 
1344
  # semantic embedding
1345
  self.semantic_weight = semantic_weight
1346
+ self.freeze_semantic_embedding = freeze_semantic_embedding
1347
  if self.semantic_weight >= 0:
1348
  self.semantic_embed_w = ModuleList()
1349
  self.semantic_embed_b = ModuleList()
 
1353
  semantic_embed_w = nn.Linear(2, self.num_features[i + 1])
1354
  semantic_embed_b = nn.Linear(2, self.num_features[i + 1])
1355
  # TODO: Test with semantic embed unfreeze
1356
+ if self.freeze_semantic_embedding:
1357
+ for param in semantic_embed_w.parameters():
1358
+ param.requires_grad = False
1359
+ for param in semantic_embed_b.parameters():
1360
+ param.requires_grad = False
1361
  trunc_normal_init(semantic_embed_w, std=0.02, bias=0.0)
1362
  trunc_normal_init(semantic_embed_b, std=0.02, bias=0.0)
1363
  self.semantic_embed_w.append(semantic_embed_w)