# Packages torch.hub verifies are importable before exposing the entry points.
dependencies = ['torch']

import torch

from modeling_sagvit import SAGViTClassifier

# Download URL for the pretrained SAG-ViT state dict. It is empty in this
# file, so SAGViT(pretrained=True) raises until a real URL is filled in.
_PRETRAINED_URL = ''


def SAGViT(pretrained=False, **kwargs):
    """
    SAG-ViT model endpoint.

    Builds a ``SAGViTClassifier`` and optionally loads pretrained weights
    into it.

    Args:
        pretrained (bool): If True, downloads the state dict from the
            configured checkpoint URL and loads it into the model.
        **kwargs: Forwarded unchanged to ``SAGViTClassifier``.

    Returns:
        model (nn.Module): The SAG-ViT model as proposed in the paper:
        SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph
        Attention for Vision Transformers.
        https://doi.org/10.48550/arXiv.2411.09420

    Raises:
        RuntimeError: If ``pretrained`` is True but no checkpoint URL is
            configured.
    """
    # Fail fast with an actionable message: handing an empty URL to
    # torch.hub only produces an obscure error unrelated to the real cause.
    if pretrained and not _PRETRAINED_URL:
        raise RuntimeError(
            "SAGViT(pretrained=True) was requested, but no pretrained "
            "checkpoint URL is configured. Use pretrained=False and load "
            "weights manually with model.load_state_dict(...)."
        )

    model = SAGViTClassifier(**kwargs)

    if pretrained:
        # map_location='cpu' lets a checkpoint saved from GPU tensors be
        # deserialized on CPU-only hosts; load_state_dict copies the values
        # into the model's existing parameters, so the model's device is
        # unaffected.
        state_dict = torch.hub.load_state_dict_from_url(
            _PRETRAINED_URL, map_location='cpu', progress=True
        )
        model.load_state_dict(state_dict)

    return model