p1atdev commited on
Commit
0a6ee48
·
verified ·
1 Parent(s): 03dbc2c

Upload modeling_siglip.py

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +57 -0
modeling_siglip.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
7
+ from transformers.utils import ModelOutput
8
+
9
+
10
+ @dataclass
11
+ class SiglipForImageClassifierOutput(ModelOutput):
12
+ loss: torch.FloatTensor | None = None
13
+ logits: torch.FloatTensor | None = None
14
+ pooler_output: torch.FloatTensor | None = None
15
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
16
+ attentions: tuple[torch.FloatTensor, ...] | None = None
17
+
18
+
19
+ class SiglipForImageClassification(SiglipPreTrainedModel):
20
+ config_class = SiglipVisionConfig
21
+ main_input_name = "pixel_values"
22
+
23
+ def __init__(
24
+ self,
25
+ config,
26
+ ):
27
+ super().__init__(config)
28
+
29
+ self.num_labels = config.num_labels
30
+ self.siglip = SiglipVisionModel(config)
31
+
32
+ # Classifier head
33
+ self.classifier = (
34
+ nn.Linear(config.hidden_size, config.num_labels)
35
+ if config.num_labels > 0
36
+ else nn.Identity()
37
+ )
38
+
39
+ # Initialize weights and apply final processing
40
+ self.post_init()
41
+
42
+ def forward(
43
+ self, pixel_values: torch.FloatTensor, labels: torch.LongTensor | None = None
44
+ ):
45
+ outputs = self.siglip(pixel_values)
46
+ pooler_output = outputs.pooler_output
47
+ logits = self.classifier(pooler_output)
48
+
49
+ loss = None
50
+
51
+ return SiglipForImageClassifierOutput(
52
+ loss=loss,
53
+ logits=logits,
54
+ pooler_output=outputs.pooler_output,
55
+ hidden_states=outputs.hidden_states,
56
+ attentions=outputs.attentions,
57
+ )