marcusinthesky commited on
Commit
c46712e
·
1 Parent(s): 737c4bf

Upload model

Browse files
Files changed (3) hide show
  1. config.json +39 -0
  2. modelling.py +236 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/sam-vit-base",
3
+ "architectures": [
4
+ "SamVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoModel": "modelling.SamVisionModel"
9
+ },
10
+ "dropout": 0.0,
11
+ "global_attn_indexes": [
12
+ 2,
13
+ 5,
14
+ 8,
15
+ 11
16
+ ],
17
+ "hidden_act": "gelu",
18
+ "hidden_size": 768,
19
+ "image_size": 1024,
20
+ "initializer_factor": 1.0,
21
+ "initializer_range": 1e-10,
22
+ "intermediate_size": 6144,
23
+ "layer_norm_eps": 1e-06,
24
+ "mlp_dim": 3072,
25
+ "mlp_ratio": 4.0,
26
+ "num_attention_heads": 12,
27
+ "num_channels": 3,
28
+ "num_hidden_layers": 12,
29
+ "num_pos_feats": 128,
30
+ "output_channels": 256,
31
+ "patch_size": 16,
32
+ "projection_dim": 512,
33
+ "qkv_bias": true,
34
+ "torch_dtype": "float32",
35
+ "transformers_version": "4.32.0.dev0",
36
+ "use_abs_pos": true,
37
+ "use_rel_pos": true,
38
+ "window_size": 14
39
+ }
modelling.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/12_modelling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['VTDEConfig', 'VTDEModel', 'SamVisionPreTrainedModel', 'SamVisionModel']
5
+
6
+ # %% ../notebooks/12_modelling.ipynb 1
7
+ from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss
8
+ from typing import Optional, Tuple, Union
9
+ from transformers import PreTrainedModel, VisionTextDualEncoderModel
10
+ import torch
11
+ from transformers import VisionTextDualEncoderConfig
12
+
13
+ class VTDEConfig(VisionTextDualEncoderConfig):
14
+ model_type = "vtde"
15
+
16
+ def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
17
+ text_pooling_mode='mean',
18
+ vision_pooling_mode='max',
19
+ **kwargs):
20
+ """
21
+ pooling_mode in ['mean', 'max', 'cls']
22
+ https://arxiv.org/pdf/2210.09996.pdf
23
+ https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
24
+ also
25
+ https://arxiv.org/pdf/2301.07836.pdf
26
+ """
27
+ self.text_pooling_mode = text_pooling_mode
28
+ self.vision_pooling_mode = vision_pooling_mode
29
+ super().__init__(projection_dim, logit_scale_init_value, **kwargs)
30
+
31
+ VTDEConfig.register_for_auto_class()
32
+
33
+
34
+ class VTDEModel(VisionTextDualEncoderModel):
35
+ config_class = VTDEConfig
36
+ base_model_prefix = "vtde"
37
+
38
+ def __init__(
39
+ self,
40
+ config: Optional[VTDEConfig] = None,
41
+ vision_model: Optional[PreTrainedModel] = None,
42
+ text_model: Optional[PreTrainedModel] = None,
43
+ ):
44
+ # You can customize the constructor if needed
45
+ super().__init__(config, vision_model, text_model)
46
+ self.text_pooling_mode = config.text_pooling_mode
47
+ self.vision_pooling_mode = config.vision_pooling_mode
48
+
49
+ def get_text_features(
50
+ self,
51
+ input_ids=None,
52
+ attention_mask=None,
53
+ position_ids=None,
54
+ token_type_ids=None,
55
+ output_attentions=None,
56
+ output_hidden_states=None,
57
+ return_dict=None,
58
+ ):
59
+ text_outputs = self.text_model(
60
+ input_ids=input_ids,
61
+ attention_mask=attention_mask,
62
+ token_type_ids=token_type_ids,
63
+ position_ids=position_ids,
64
+ output_attentions=output_attentions,
65
+ output_hidden_states=output_hidden_states,
66
+ return_dict=return_dict,
67
+ )
68
+ if self.text_pooling_mode == 'cls':
69
+ pooled_output = text_outputs[1]
70
+ elif self.text_pooling_mode == 'mean':
71
+ pooled_output = torch.mean(text_outputs[0], dim=1)
72
+ elif self.text_pooling_mode == 'max':
73
+ pooled_output = torch.max(text_outputs[0], dim=1)[0]
74
+ elif self.text_pooling_mode == 'norm':
75
+ """we select the patch with the largest norm"""
76
+ last_hidden_states = text_outputs[0]
77
+ patch_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
78
+ max_norm_idx = torch.argmax(patch_norms, dim=1)
79
+ pooled_output = last_hidden_states[:, max_norm_idx, :][:, 0, :]
80
+ else:
81
+ "We want to raise the name of the pooling mode"
82
+ raise NotImplementedError
83
+
84
+ text_features = self.text_projection(pooled_output)
85
+
86
+ return text_features
87
+
88
+ def get_image_features(
89
+ self,
90
+ pixel_values=None,
91
+ output_attentions=None,
92
+ output_hidden_states=None,
93
+ return_dict=None,
94
+ ):
95
+ vision_outputs = self.vision_model(
96
+ pixel_values=pixel_values,
97
+ output_attentions=output_attentions,
98
+ output_hidden_states=output_hidden_states,
99
+ return_dict=return_dict,
100
+ )
101
+
102
+ if self.vision_pooling_mode == 'cls':
103
+ pooled_output = vision_outputs[1]
104
+ elif self.vision_pooling_mode == 'mean':
105
+ pooled_output = torch.mean(vision_outputs[0], dim=1)
106
+ elif self.vision_pooling_mode == 'max':
107
+ pooled_output = torch.max(vision_outputs[0], dim=1)[0]
108
+ elif self.vision_pooling_mode == 'norm':
109
+ """we select the patch with the largest norm"""
110
+ last_hidden_states = vision_outputs[0]
111
+ patch_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
112
+ max_norm_idx = torch.argmax(patch_norms, dim=1)
113
+ pooled_output = last_hidden_states[:, max_norm_idx, :][:, 0, :]
114
+ else:
115
+ raise NotImplementedError
116
+
117
+ image_features = self.visual_projection(pooled_output)
118
+
119
+ return image_features
120
+
121
+ def forward(
122
+ self,
123
+ input_ids: Optional[torch.LongTensor] = None,
124
+ pixel_values: Optional[torch.FloatTensor] = None,
125
+ attention_mask: Optional[torch.Tensor] = None,
126
+ position_ids: Optional[torch.LongTensor] = None,
127
+ return_loss: Optional[bool] = None,
128
+ token_type_ids: Optional[torch.LongTensor] = None,
129
+ output_attentions: Optional[bool] = None,
130
+ output_hidden_states: Optional[bool] = None,
131
+ return_dict: Optional[bool] = None,
132
+ ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
133
+
134
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
135
+
136
+ image_embeds = self.get_image_features(
137
+ pixel_values=pixel_values,
138
+ output_attentions=output_attentions,
139
+ output_hidden_states=output_hidden_states,
140
+ return_dict=return_dict,
141
+ )
142
+
143
+ text_embeds = self.get_text_features(
144
+ input_ids=input_ids,
145
+ attention_mask=attention_mask,
146
+ position_ids=position_ids,
147
+ output_attentions=output_attentions,
148
+ output_hidden_states=output_hidden_states,
149
+ return_dict=return_dict,
150
+ )
151
+
152
+ # normalized features
153
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
154
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
155
+
156
+ # cosine similarity as logits
157
+ logit_scale = self.logit_scale.exp()
158
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
159
+ logits_per_image = logits_per_text.T
160
+
161
+ loss = None
162
+ if return_loss:
163
+ loss = clip_loss(logits_per_text)
164
+
165
+ if not return_dict:
166
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_embeds, image_embeds)
167
+ return ((loss,) + output) if loss is not None else output
168
+
169
+ return CLIPOutput(
170
+ loss=loss,
171
+ logits_per_image=logits_per_image,
172
+ logits_per_text=logits_per_text,
173
+ text_embeds=text_embeds,
174
+ image_embeds=image_embeds,
175
+ text_model_output=text_embeds,
176
+ vision_model_output=image_embeds,
177
+ )
178
+
179
+
180
+ VTDEModel.register_for_auto_class("AutoModel")
181
+ VTDEModel.register_for_auto_class("AutoModelForZeroShotImageClassification")
182
+
183
+ # %% ../notebooks/12_modelling.ipynb 2
184
+ # we want to create a vision-text encoder model for SAM
185
+ from transformers import PreTrainedModel
186
+ from transformers.models.sam.modeling_sam import SamPositionalEmbedding, SamVisionEncoder, SamVisionEncoderOutput
187
+ from transformers.models.sam.configuration_sam import SamVisionConfig
188
+ from torch import nn
189
+
190
+ class SamVisionPreTrainedModel(PreTrainedModel):
191
+ config_class = SamVisionConfig
192
+ base_model_prefix = "sam_vision_encoder"
193
+ main_input_name = "pixel_values"
194
+
195
+ def _init_weights(self, module):
196
+ std = self.config.initializer_range
197
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
198
+ module.weight.data.normal_(mean=0.0, std=std)
199
+ if module.bias is not None:
200
+ module.bias.data.zero_()
201
+ elif isinstance(module, nn.Embedding):
202
+ module.weight.data.normal_(mean=0.0, std=std)
203
+ if module.padding_idx is not None:
204
+ module.weight.data[module.padding_idx].zero_()
205
+
206
+ class SamVisionModel(SamVisionPreTrainedModel):
207
+
208
+ def __init__(self, config):
209
+ super().__init__(config)
210
+ self.shared_image_embedding = SamPositionalEmbedding(config)
211
+ self.vision_encoder = SamVisionEncoder(config)
212
+
213
+ def forward(
214
+ self,
215
+ pixel_values=None,
216
+ attention_mask=None,
217
+ output_attentions=None,
218
+ output_hidden_states=None,
219
+ return_dict=None,
220
+ ) -> SamVisionEncoderOutput:
221
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
222
+
223
+ image_embeddings = self.shared_image_embedding(pixel_values)
224
+ vision_encoder_outputs = self.vision_encoder(
225
+ image_embeddings,
226
+ attention_mask=attention_mask,
227
+ output_attentions=output_attentions,
228
+ output_hidden_states=output_hidden_states,
229
+ return_dict=return_dict,
230
+ )
231
+
232
+ return vision_encoder_outputs
233
+
234
+ SamVisionModel.register_for_auto_class("AutoModel")
235
+ # SamVisionConfig.register_for_auto_class()
236
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cebd0dc11b2b662674af8eab00a2dd992c81af72e70f4827d45de717822e935
3
+ size 358741525