File size: 772 Bytes
039647a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
dependencies = ['torch']

from sag_vit_model import SAGViTClassifier  
import torch

def SAGViT(pretrained=False, **kwargs):
    """
    SAG-ViT model endpoint.
    Args:
        pretrained (bool): If True, loads pretrained weights.
        **kwargs: Additional arguments for the model.
    Returns:
        model (nn.Module): The SAG-ViT model as proposed in the
        paper: SAG-ViT: A Scale-Aware, High-Fidelity Patching 
        Approach with Graph Attention for Vision Transformers.
        https://doi.org/10.48550/arXiv.2411.09420
    """
    model = SAGViTClassifier(**kwargs)
    if pretrained:
        checkpoint = ''
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
        model.load_state_dict(state_dict)
    return model