File size: 618 Bytes
aea26c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from .vit import vit
import numpy as np 

def create_backbone(cfg, *, fast_vit_checkpoint='./pretrained_vit/fastvit_ma36.pt'):
    """Build the backbone network selected by ``cfg.MODEL.BACKBONE.TYPE``.

    Supported types:

    * ``'vit'``      -- delegates to the package-local ``vit(cfg)`` factory.
    * ``'fast_vit'`` -- builds timm's ``fastvit_ma36`` with a drop-path rate
      of 0.2 and loads pretrained weights from ``fast_vit_checkpoint``.

    Args:
        cfg: Config object exposing ``cfg.MODEL.BACKBONE.TYPE``.
        fast_vit_checkpoint: Path of the FastViT checkpoint file; only used
            for the ``'fast_vit'`` type. A relative path is resolved against
            the current working directory. The default is the path that was
            previously hard-coded here.

    Returns:
        The constructed backbone model.

    Raises:
        NotImplementedError: If the backbone type is not one of the above.
    """
    backbone_type = cfg.MODEL.BACKBONE.TYPE

    if backbone_type == 'vit':
        return vit(cfg)

    if backbone_type == 'fast_vit':
        # Imported inside this branch so torch/timm are not needed by this
        # function when a different backbone type is requested.
        import torch
        from timm.models import create_model

        fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2)
        # map_location='cpu' lets a checkpoint saved from GPU load on a
        # CPU-only host; load_state_dict copies the tensors into the model.
        checkpoint = torch.load(fast_vit_checkpoint, map_location='cpu')
        # Accept both a {'state_dict': ...} wrapper and a bare state dict.
        # NOTE(review): key names must match timm's fastvit_ma36 layout
        # (load_state_dict is strict) — confirm against the checkpoint source.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        fast_vit.load_state_dict(state_dict)
        return fast_vit

    raise NotImplementedError(
        f'Backbone type {backbone_type!r} is not implemented')