Update seagull/model/layer.py
Browse files
- seagull/model/layer.py +5 -5
seagull/model/layer.py
CHANGED
@@ -77,8 +77,8 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
|
|
77 |
return mask_feat, global_mask
|
78 |
|
79 |
def forward(self, feats, masks, cropped_img):
|
80 |
-
|
81 |
-
|
82 |
num_imgs = len(masks)
|
83 |
|
84 |
for idx in range(num_imgs):
|
@@ -108,16 +108,16 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
|
|
108 |
mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
|
109 |
|
110 |
query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
|
111 |
-
|
112 |
|
113 |
cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
114 |
global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
|
115 |
global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
|
116 |
pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
|
117 |
|
118 |
-
|
119 |
|
120 |
-
return
|
121 |
|
122 |
class MaskPooling(nn.Module):
|
123 |
def __init__(self):
|
|
|
77 |
return mask_feat, global_mask
|
78 |
|
79 |
def forward(self, feats, masks, cropped_img):
|
80 |
+
global_features_list = []
|
81 |
+
local_features_list = []
|
82 |
num_imgs = len(masks)
|
83 |
|
84 |
for idx in range(num_imgs):
|
|
|
108 |
mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
|
109 |
|
110 |
query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
|
111 |
+
global_features_list.append(query_feat) # global
|
112 |
|
113 |
cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
114 |
global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
|
115 |
global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
|
116 |
pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
|
117 |
|
118 |
+
local_features_list.append(pos_feat) #(imgs_num, 1, q, 4096) # local
|
119 |
|
120 |
+
return global_features_list, local_features_list
|
121 |
|
122 |
class MaskPooling(nn.Module):
|
123 |
def __init__(self):
|