hasibzunair's picture
inital files
46fdf2a
raw
history blame
1.69 kB
import torch
import torch.nn as nn
class CSRA(nn.Module): # one basic block
def __init__(self, input_dim, num_classes, T, lam):
super(CSRA, self).__init__()
self.T = T # temperature
self.lam = lam # Lambda
self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
self.softmax = nn.Softmax(dim=2)
def forward(self, x):
# x (B d H W)
# normalize classifier
# score (B C HxW)
score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
score = score.flatten(2)
base_logit = torch.mean(score, dim=2)
if self.T == 99: # max-pooling
att_logit = torch.max(score, dim=2)[0]
else:
score_soft = self.softmax(score * self.T)
# https://github.com/Kevinz-code/CSRA/issues/5
att_logit = torch.sum(score * score_soft, dim=2)
return base_logit + self.lam * att_logit
class MHA(nn.Module): # multi-head attention
temp_settings = { # softmax temperature settings
1: [1],
2: [1, 99],
4: [1, 2, 4, 99],
6: [1, 2, 3, 4, 5, 99],
8: [1, 2, 3, 4, 5, 6, 7, 99]
}
def __init__(self, num_heads, lam, input_dim, num_classes):
super(MHA, self).__init__()
self.temp_list = self.temp_settings[num_heads]
self.multi_head = nn.ModuleList([
CSRA(input_dim, num_classes, self.temp_list[i], lam)
for i in range(num_heads)
])
def forward(self, x):
logit = 0.
for head in self.multi_head:
logit += head(x)
return logit