Zevin2023 commited on
Commit
eb97d8c
1 Parent(s): 45c48cf

Update seagull/model/layer.py

Browse files
Files changed (1) hide show
  1. seagull/model/layer.py +5 -5
seagull/model/layer.py CHANGED
@@ -77,8 +77,8 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
77
  return mask_feat, global_mask
78
 
79
  def forward(self, feats, masks, cropped_img):
80
- global_features = []
81
- local_features = []
82
  num_imgs = len(masks)
83
 
84
  for idx in range(num_imgs):
@@ -108,16 +108,16 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
108
  mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
109
 
110
  query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
111
- global_features.append(query_feat) # global
112
 
113
  cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
114
  global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
115
  global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
116
  pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
117
 
118
- local_features.append(pos_feat) #(imgs_num, 1, q, 4096) # local
119
 
120
- return global_features, local_features
121
 
122
  class MaskPooling(nn.Module):
123
  def __init__(self):
 
77
  return mask_feat, global_mask
78
 
79
  def forward(self, feats, masks, cropped_img):
80
+ global_features_list = []
81
+ local_features_list = []
82
  num_imgs = len(masks)
83
 
84
  for idx in range(num_imgs):
 
108
  mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
109
 
110
  query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
111
+ global_features_list.append(query_feat) # global
112
 
113
  cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
114
  global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
115
  global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
116
  pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
117
 
118
+ local_features_list.append(pos_feat) #(imgs_num, 1, q, 4096) # local
119
 
120
+ return global_features_list, local_features_list
121
 
122
  class MaskPooling(nn.Module):
123
  def __init__(self):