import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from modelscope.msdatasets import MsDataset


class Interpolate(nn.Module):
    """Thin nn.Module wrapper around F.interpolate so it can be used inside nn.Sequential."""

    def __init__(
        self,
        size=None,
        scale_factor=None,
        mode="bilinear",
        align_corners=False,
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )


class t_EvalNet:
    """Frame-level evaluation network: a torchvision backbone plus a
    transposed-convolution classifier that restores the original time axis."""

    def __init__(
        self,
        backbone: str,
        cls_num: int,
        ori_T: int,
        imgnet_ver="v1",
        weight_path="",
    ):
        if not hasattr(models, backbone):
            raise ValueError(f"Unsupported model {backbone}.")

        self.imgnet_ver = imgnet_ver
        self.type, self.weight_url, self.input_size = self._model_info(backbone)
        # Instantiate the torchvision backbone by name (avoids eval()).
        self.model: torch.nn.Module = getattr(models, backbone)()
        self.ori_T = ori_T

        if self.type == "vit":
            self.hidden_dim = self.model.hidden_dim
            self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        elif self.type == "swin_transformer":
            self.hidden_dim = 768

        self.cls_num = cls_num
        self._set_classifier()

        checkpoint = (
            torch.load(weight_path)
            if torch.cuda.is_available()
            else torch.load(weight_path, map_location="cpu")
        )
        self.model.load_state_dict(checkpoint["model"], strict=False)
        self.classifier.load_state_dict(checkpoint["classifier"], strict=False)

        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

        self.model.eval()
        self.classifier.eval()  # keep the BatchNorm layers in inference mode

    def _get_backbone(self, backbone_ver, backbone_list):
        for backbone_info in backbone_list:
            if backbone_ver == backbone_info["ver"]:
                return backbone_info

        raise ValueError("[Backbone not found] Please check if --model is correct!")

    def _model_info(self, backbone: str):
        backbone_list = MsDataset.load(
            "monetjoe/cv_backbones",
            split=self.imgnet_ver,
            cache_dir="./__pycache__",
        )
        backbone_info = self._get_backbone(backbone, backbone_list)
        return (
            str(backbone_info["type"]),
            str(backbone_info["url"]),
            int(backbone_info["input_size"]),
        )

    def _create_classifier(self):
        original_T_size = self.ori_T
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))  # F -> 1
        upsample_module = nn.Sequential(
            # nn.AdaptiveAvgPool2d((1, None)),  # F -> 1
            nn.ConvTranspose2d(
                self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(
                256, 128, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(
                128, 64, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(
                64, 32, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            # input for Interp: [bsz, C, 1, T]
            Interpolate(
                size=(1, original_T_size), mode="bilinear", align_corners=False
            ),
            # classifier
            nn.Conv2d(32, 32, kernel_size=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
        )
        return upsample_module

    def _set_classifier(self):
        #### set custom classifier ####
        if self.type == "vit" or self.type == "swin_transformer":
            self.classifier = self._create_classifier()

    def get_input_size(self):
        return self.input_size

    def forward(self, x: torch.Tensor):
        if torch.cuda.is_available():
            x = x.cuda()

        if self.type == "vit":
            x = self.model._process_input(x)
            # Prepend the class token, matching the device of the input batch.
            batch_class_token = self.class_token.expand(x.size(0), -1, -1).to(x.device)
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.model.encoder(x)
            # Drop the class token and reshape the patch tokens.
            x = x[:, 1:].permute(0, 2, 1)
            x = x.unsqueeze(2)  # x shape: [bsz, hidden_dim, 1, seq_len]
            x = self.classifier(x).squeeze()
            return x

        elif self.type == "swin_transformer":
            x = self.model.features(x)  # [B, H, W, C]
            x = x.permute(0, 3, 1, 2)  # [B, C, H, W]
            x = self.avgpool(x)  # [B, C, 1, W]
            x = self.classifier(x).squeeze()
            return x

        return None
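
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the backbone name, class count, target
# time length, and checkpoint path below are assumptions for demonstration,
# not values shipped with this module. A real checkpoint produced by the
# matching training script is required for load_state_dict to succeed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    net = t_EvalNet(
        backbone="vit_b_16",          # assumed torchvision ViT backbone
        cls_num=2,                    # assumed number of frame-level classes
        ori_T=1024,                   # assumed original time-axis length
        weight_path="./weights.pth",  # hypothetical checkpoint path
    )
    size = net.get_input_size()
    dummy = torch.randn(1, 3, size, size)
    with torch.no_grad():
        logits = net.forward(dummy)   # expected shape: [cls_num, ori_T] when bsz == 1
    print(logits.shape)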