Upload model
Browse files- modeling_solider.py +9 -5
modeling_solider.py
CHANGED
@@ -1234,7 +1234,9 @@ class SwinTransformer(BaseModule):
|
|
1234 |
convert_weights=False,
|
1235 |
frozen_stages=-1,
|
1236 |
init_cfg=None,
|
1237 |
-
|
|
|
|
|
1238 |
):
|
1239 |
self.convert_weights = convert_weights
|
1240 |
self.frozen_stages = frozen_stages
|
@@ -1341,6 +1343,7 @@ class SwinTransformer(BaseModule):
|
|
1341 |
|
1342 |
# semantic embedding
|
1343 |
self.semantic_weight = semantic_weight
|
|
|
1344 |
if self.semantic_weight >= 0:
|
1345 |
self.semantic_embed_w = ModuleList()
|
1346 |
self.semantic_embed_b = ModuleList()
|
@@ -1350,10 +1353,11 @@ class SwinTransformer(BaseModule):
|
|
1350 |
semantic_embed_w = nn.Linear(2, self.num_features[i + 1])
|
1351 |
semantic_embed_b = nn.Linear(2, self.num_features[i + 1])
|
1352 |
# TODO: Test with semantic embed unfreeze
|
1353 |
-
|
1354 |
-
param
|
1355 |
-
|
1356 |
-
param
|
|
|
1357 |
trunc_normal_init(semantic_embed_w, std=0.02, bias=0.0)
|
1358 |
trunc_normal_init(semantic_embed_b, std=0.02, bias=0.0)
|
1359 |
self.semantic_embed_w.append(semantic_embed_w)
|
|
|
1234 |
convert_weights=False,
|
1235 |
frozen_stages=-1,
|
1236 |
init_cfg=None,
|
1237 |
+
# NOTE: This is my modification based on SOLIDER
|
1238 |
+
semantic_weight=0.5,
|
1239 |
+
freeze_semantic_embedding=False,
|
1240 |
):
|
1241 |
self.convert_weights = convert_weights
|
1242 |
self.frozen_stages = frozen_stages
|
|
|
1343 |
|
1344 |
# semantic embedding
|
1345 |
self.semantic_weight = semantic_weight
|
1346 |
+
self.freeze_semantic_embedding = freeze_semantic_embedding
|
1347 |
if self.semantic_weight >= 0:
|
1348 |
self.semantic_embed_w = ModuleList()
|
1349 |
self.semantic_embed_b = ModuleList()
|
|
|
1353 |
semantic_embed_w = nn.Linear(2, self.num_features[i + 1])
|
1354 |
semantic_embed_b = nn.Linear(2, self.num_features[i + 1])
|
1355 |
# TODO: Test with semantic embed unfreeze
|
1356 |
+
if self.freeze_semantic_embedding:
|
1357 |
+
for param in semantic_embed_w.parameters():
|
1358 |
+
param.requires_grad = False
|
1359 |
+
for param in semantic_embed_b.parameters():
|
1360 |
+
param.requires_grad = False
|
1361 |
trunc_normal_init(semantic_embed_w, std=0.02, bias=0.0)
|
1362 |
trunc_normal_init(semantic_embed_b, std=0.02, bias=0.0)
|
1363 |
self.semantic_embed_w.append(semantic_embed_w)
|