import torch
import torch.nn as nn


class CSRA(nn.Module):  # one basic block (class-specific residual attention head)
    def __init__(self, input_dim, num_classes, T, lam):
        super(CSRA, self).__init__()
        self.T = T      # temperature; T == 99 is a sentinel that selects max pooling
        self.lam = lam  # lambda, weight of the attention logit in the residual sum
        self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # x: (B, d, H, W) backbone feature map
        # normalize the classifier weights so scores are invariant to weight scale
        # score: (B, C, H*W) after flattening the spatial dimensions
        score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)  # average pooling over all locations

        if self.T == 99:  # max pooling
            att_logit = torch.max(score, dim=2)[0]
        else:
            score_soft = self.softmax(score * self.T)
            # https://github.com/Kevinz-code/CSRA/issues/5
            att_logit = torch.sum(score * score_soft, dim=2)

        return base_logit + self.lam * att_logit


class MHA(nn.Module):  # multi-head attention
    temp_settings = {  # softmax temperature assigned to each head, keyed by head count
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99],
    }

    def __init__(self, num_heads, lam, input_dim, num_classes):
        super(MHA, self).__init__()
        self.temp_list = self.temp_settings[num_heads]
        self.multi_head = nn.ModuleList([
            CSRA(input_dim, num_classes, self.temp_list[i], lam)
            for i in range(num_heads)
        ])

    def forward(self, x):
        # sum the logits from every head
        logit = 0.
        for head in self.multi_head:
            logit += head(x)
        return logit
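

# A minimal, hypothetical usage sketch (not part of the original module): the
# head count, lambda, batch size, and spatial size below are illustrative
# assumptions; input_dim=2048 matches a ResNet-101 backbone and num_classes=20
# matches VOC2007, but any compatible values work.
if __name__ == "__main__":
    model = MHA(num_heads=2, lam=0.1, input_dim=2048, num_classes=20)
    feats = torch.randn(4, 2048, 7, 7)  # (B, d, H, W) backbone feature map
    logits = model(feats)               # raw per-class logits (no sigmoid applied)
    print(logits.shape)                 # torch.Size([4, 20])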