from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


class ExplainerConfig(PretrainedConfig):
    """Configuration for the SigLIP-based bounding-box Explainer."""

    model_type = "explainer"

    def __init__(self, base_model_name='google/siglip2-giant-opt-patch16-384',
                 hidden_dim=768, giant=True, **kwargs):
        self.base_model_name = base_model_name
        self.hidden_dim = hidden_dim
        self.giant = giant
        super().__init__(**kwargs)


class SigLIPBBoxRegressor(nn.Module):
    """Regresses a bounding box from pooled SigLIP image and text embeddings."""

    def __init__(self, siglip_model, hidden_dim=768, giant=True):
        super().__init__()
        self.siglip = siglip_model

        vision_dim = self.siglip.vision_model.config.hidden_size
        text_dim = self.siglip.text_model.config.hidden_size
        if giant:
            # The giant checkpoint pools text features into 1536-dim embeddings,
            # which differs from the text tower's hidden_size read above.
            text_dim = 1536

        # Project both modalities into a shared hidden space.
        self.vision_projector = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        self.text_projector = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Fuse the concatenated image/text projections.
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Separate heads for the top-left and bottom-right box corners.
        self.topleft_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )
        self.bottomright_regressor = nn.Sequential(
            nn.Linear(hidden_dim // 2, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )
    def forward(self, pixel_values, input_ids):
        # Use the SigLIP backbone as a frozen feature extractor.
        with torch.no_grad():
            outputs = self.siglip(pixel_values=pixel_values, input_ids=input_ids, return_dict=True)

        # Pooled embeddings, cast to float32 in case the backbone runs in half precision.
        vision_features = outputs.image_embeds.float()
        text_features = outputs.text_embeds.float()

        vision_proj = self.vision_projector(vision_features)
        text_proj = self.text_projector(text_features)

        # Concatenate both modalities and fuse.
        fused = torch.cat([vision_proj, text_proj], dim=1)
        fused_features = self.fusion_layer(fused)

        topleft_pred = self.topleft_regressor(fused_features)
        bottomright_pred = self.bottomright_regressor(fused_features)

        # (batch, 4): top-left (x, y) followed by bottom-right (x, y).
        return torch.cat([topleft_pred, bottomright_pred], dim=1)
class Explainer(PreTrainedModel):
    config_class = ExplainerConfig

    def __init__(self, config):
        super().__init__(config)
        self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
        self.bbox_regressor = SigLIPBBoxRegressor(
            self.siglip_model, hidden_dim=config.hidden_dim, giant=config.giant
        )
        self.processor = AutoProcessor.from_pretrained(config.base_model_name, use_fast=True)

    def forward(self, pixel_values=None, input_ids=None):
        return self.bbox_regressor(pixel_values, input_ids)
    def predict(self, image, text, device="cuda"):
        self.to(device)
        self.eval()
        inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64
        )
        # Match the backbone's dtype rather than hard-coding half precision.
        pixel_values = inputs["pixel_values"].to(device=device, dtype=self.dtype)
        input_ids = inputs["input_ids"].to(device)
        with torch.no_grad():
            pred_bbox = self.forward(pixel_values, input_ids)
        return pred_bbox[0].cpu().numpy().tolist()
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = ExplainerConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        # The checkpoint is a single file containing two state dicts,
        # keyed "siglip_model" and "bbox_regressor".
        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="pytorch_model.bin"
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.siglip_model.load_state_dict(checkpoint["siglip_model"])
        model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
        return model
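

# Usage sketch (illustrative): load a published checkpoint and predict a box for a
# text query. The repo id and image path below are placeholders, not values defined
# in this module; adjust them to your own checkpoint and data.
if __name__ == "__main__":
    from PIL import Image

    model = Explainer.from_pretrained("your-username/explainer")  # hypothetical repo id
    image = Image.open("example.jpg").convert("RGB")              # any local RGB image
    box = model.predict(image, "the red car on the left", device="cuda")  # use "cpu" if no GPU
    # predict() returns the four regressed values: [x1, y1, x2, y2].
    print(box)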