import sys

sys.path.insert(0, '/data1/PycharmProjects/FGVC11/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models-main')
# sys.path.insert(0, '/data1/PycharmProjects/FGVC11/submision/pytorch-image-models')

import timm
import torch
from torch import nn
import torch.nn.functional as F
from timm.layers import LayerNorm2d, LayerNorm, NormMlpClassifierHead, ClassifierHead
from timm.models.convnext import ConvNeXtStage
import numpy as np

BASE_CP = ''   # path to pretrain weights (convnextv2_base backbone)
LARGE_CP = ''  # path to pretrain weights (convnext_large_mlp)


class expert(nn.Module):
    def __init__(self, model_arch, num_classes, pretrain=True) -> None:
        super().__init__()
        # Replica of the last ConvNeXt stage: maps stage-3 features (768 ch) to 1536 ch.
        self.model = ConvNeXtStage(
            in_chs=768,
            out_chs=1536,
            kernel_size=7,
            stride=2,
            dilation=(1, 1),
            depth=3,
            drop_path_rates=[0.0, 0.0, 0.0],
            ls_init_value=1e-6,
            conv_mlp=False,
            conv_bias=True,
            use_grn=False,
            act_layer='gelu',
            norm_layer=LayerNorm2d,
            norm_layer_cl=LayerNorm,
        )
        self.cls_head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # max
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        if model_arch == 'convnext_large_mlp':
            checkpoints = LARGE_CP
        else:
            assert False, 'pretrain weight not found'
        print('use pretrain weight:', checkpoints)
        state_dict = torch.load(checkpoints)
        # Strip the DataParallel 'module.' prefix.
        for key in list(state_dict.keys()):
            if key.startswith('module.'):
                new_key = key[7:]
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Keep only the last-stage ('stages.3.') weights and remap them onto this module.
        for key in list(state_dict.keys()):
            if 'stages.3.' not in key:
                del state_dict[key]
            elif key.startswith('stages.3.'):
                new_key = key[9:]
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        self.model.load_state_dict(state_dict, strict=True)
        del state_dict

    def forward(self, out_stage3):
        out = self.model(out_stage3)
        out = self.cls_head(out)
        return out


class Moe(nn.Module):
    def __init__(self, model_arch, num_classes, mask, pretrain=True) -> None:
        super().__init__()
        if pretrain:
            # out_stage3=True makes the (modified) timm backbone also return stage-3 features.
            self.backbone = timm.create_model(model_arch, num_classes=0, pretrained=False, out_stage3=True)
            if model_arch == 'convnextv2_base.fcmae_ft_in22k_in1k_384':
                checkpoints = BASE_CP
            elif model_arch == 'convnext_large_mlp':
                checkpoints = LARGE_CP
            else:
                assert False, 'pretrain weight not found'
            print('use pretrain weight:', checkpoints)
            state_dict = torch.load(checkpoints)
            self.backbone.load_state_dict(state_dict, strict=False)
            del state_dict
        self.head = NormMlpClassifierHead(
            in_features=1536,
            num_classes=num_classes,
            hidden_size=1536,
            pool_type='avg',  # max
            drop_rate=0.0,
            norm_layer=LayerNorm2d,
            act_layer='gelu',
        )
        self.expert_venomous = expert(model_arch, num_classes)
        self.expert_not_venomous = expert(model_arch, num_classes)
        # Gating head over concatenated max-pooled stage-3 (768) and stage-4 (1536) features.
        self.venomous_head = nn.Linear(768 + 1536, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.venomous_head.weight)
        self.venomous_mask = mask
        self.not_venomous_mask = torch.ones_like(mask) - mask

    def forward(self, x):
        out4, out3 = self.backbone(x)
        feat = torch.cat([F.adaptive_max_pool2d(out3, 1).flatten(1),
                          F.adaptive_max_pool2d(out4, 1).flatten(1)], dim=-1)
        is_venomous = self.venomous_head(feat)
        alpha = torch.sigmoid(is_venomous)
        venomous = self.expert_venomous(out3) * self.venomous_mask.to(x.device)
        not_venomous = self.expert_not_venomous(out3) * self.not_venomous_mask.to(x.device)
        y_hat = self.head(out4)
        # expert_pred = venomous * alpha + not_venomous * (1 - alpha)
        expert_pred = venomous + not_venomous
        final_pred = y_hat + expert_pred
        return y_hat, expert_pred, is_venomous, final_pred


class SeesawLossWithLogits(nn.Module):
    """
    Unofficial implementation of Seesaw loss, proposed in the technical report
    for the LVIS workshop at ECCV 2020. For more detail, please refer to
    https://arxiv.org/pdf/2008.10032.pdf.

    Args:
        class_counts: The list of the number of samples for each class.
            Should have the same length as num_classes.
        p: Scale parameter which adjusts the strength of the punishment.
            Set to 0.8 by default, following the original paper.
    """

    def __init__(self, class_counts: np.array, num_classes, p: float = 0.8):
        super().__init__()
        class_counts = torch.FloatTensor(class_counts)
        conditions = class_counts[:, None] > class_counts[None, :]
        trues = (class_counts[None, :] / class_counts[:, None]) ** p
        falses = torch.ones(len(class_counts), len(class_counts))
        # Mitigation factor S_ij: down-weights penalties from head classes onto tail classes.
        self.s = torch.where(conditions, trues, falses)
        self.num_classes = num_classes
        self.eps = 1.0e-6

    def forward(self, logits, targets):
        targets = nn.functional.one_hot(targets, num_classes=self.num_classes).float().to(targets.device)
        self.s = self.s.to(targets.device)
        max_element, _ = logits.max(axis=-1)
        logits = logits - max_element[:, None]  # to prevent overflow
        numerator = torch.exp(logits)
        denominator = (
            (1 - targets)[:, None, :]
            * self.s[None, :, :]
            * torch.exp(logits)[:, None, :]
        ).sum(axis=-1) + torch.exp(logits)
        sigma = numerator / (denominator + self.eps)
        loss = (- targets * torch.log(sigma + self.eps)).sum(-1)
        return loss.mean()


class all_loss(nn.Module):
    def __init__(self, class_counts: np.array, num_classes):
        super().__init__()
        self.main_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.venomous_loss = SeesawLossWithLogits(class_counts, num_classes)
        self.final_pred_loss = SeesawLossWithLogits(class_counts, num_classes)
        # self.venomous_loss = nn.CrossEntropyLoss()
        # self.alpha_loss = nn.BCEWithLogitsLoss()
        # self.final_pred_loss = nn.CrossEntropyLoss()

    def forward(self, y_hat, expert_pred, alpha, final_pred, targets, is_venomous):
        loss1 = self.main_loss(y_hat, targets)
        loss2 = self.venomous_loss(expert_pred, targets)
        # loss3 = self.alpha_loss(alpha, is_venomous.unsqueeze(1))
        loss4 = self.final_pred_loss(final_pred, targets)
        return (loss1 + loss2 + loss4) / 3
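

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code). It assumes:
# the modified timm fork on sys.path that accepts out_stage3=True, LARGE_CP
# pointing to a valid convnext_large_mlp checkpoint, a 384x384 input, and
# placeholder class counts / venomous mask invented purely for illustration.
# It only shows the expected shapes and how the four Moe outputs feed all_loss.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_classes = 10                                   # placeholder class count
    class_counts = np.array([120, 80, 500, 40, 60, 300, 25, 90, 15, 70])
    mask = torch.zeros(num_classes)
    mask[:4] = 1.0                                     # pretend the first 4 classes are venomous

    model = Moe('convnext_large_mlp', num_classes, mask)
    criterion = all_loss(class_counts, num_classes)

    x = torch.randn(2, 3, 384, 384)                    # assumed input resolution
    targets = torch.tensor([2, 7])                     # class labels for the batch
    is_venomous_gt = torch.tensor([1.0, 0.0])          # per-sample venomous labels (unused by the current loss)

    y_hat, expert_pred, is_venomous, final_pred = model(x)
    loss = criterion(y_hat, expert_pred, torch.sigmoid(is_venomous), final_pred, targets, is_venomous_gt)
    loss.backward()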