Spaces:
Sleeping
Sleeping
File size: 1,685 Bytes
46fdf2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
import torch.nn as nn
class CSRA(nn.Module):
    """Class-Specific Residual Attention head (one basic block).

    Produces per-class logits by combining the spatial average of the
    class score maps with an attention-pooled term: a temperature-scaled
    softmax over spatial positions (or a plain spatial max when the
    sentinel temperature 99 is used).
    """

    def __init__(self, input_dim, num_classes, T, lam):
        super(CSRA, self).__init__()
        self.T = T      # softmax temperature; 99 is a sentinel meaning max-pooling
        self.lam = lam  # weight of the attention term in the residual sum
        # 1x1 conv acts as a per-class linear classifier over feature channels.
        self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # x: (B, d, H, W) feature map.
        # L2-normalize each class's classifier weight, then score and
        # flatten the spatial dims -> (B, C, H*W).
        weight_norm = torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = (self.head(x) / weight_norm).flatten(2)
        base_logit = torch.mean(score, dim=2)
        if self.T == 99:  # sentinel: spatial max-pooling instead of soft attention
            att_logit = torch.max(score, dim=2)[0]
        else:
            # https://github.com/Kevinz-code/CSRA/issues/5
            attention = self.softmax(score * self.T)
            att_logit = torch.sum(score * attention, dim=2)
        return base_logit + self.lam * att_logit
class MHA(nn.Module):
    """Multi-head attention wrapper summing logits from several CSRA heads.

    Each head gets its own softmax temperature taken from
    ``temp_settings[num_heads]``; the value 99 selects max-pooling in CSRA.
    Raises KeyError if ``num_heads`` has no configured temperature list.
    """

    # Softmax temperature schedule per supported head count.
    temp_settings = {
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99],
    }

    def __init__(self, num_heads, lam, input_dim, num_classes):
        super(MHA, self).__init__()
        self.temp_list = self.temp_settings[num_heads]
        # One CSRA head per configured temperature.
        self.multi_head = nn.ModuleList(
            CSRA(input_dim, num_classes, temperature, lam)
            for temperature in self.temp_list
        )

    def forward(self, x):
        # Sum the logits produced by every head (starts from the float 0.).
        logit = 0.
        for head in self.multi_head:
            logit = logit + head(x)
        return logit
|