import torch
import torch.nn as nn
import torch.nn.functional as F

from models import image

# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
def KL(alpha, c):
    """KL divergence between Dir(alpha) and the uniform Dirichlet Dir(1, ..., 1)
    over c classes, computed per sample (output shape: [batch, 1])."""
    if torch.cuda.is_available():
        beta = torch.ones((1, c)).cuda()
    else:
        beta = torch.ones((1, c))
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    # Log of the Dirichlet normalization constants B(alpha) and B(beta)
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)
    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
    return kl
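

# A quick sanity check for KL (an illustrative addition, not part of the
# original model): for a flat Dirichlet, where alpha equals the uniform prior,
# the divergence should be zero, and it grows as alpha concentrates away
# from uniform. The helper name is hypothetical.
def _kl_sanity_check(c=4):
    flat = torch.ones((2, c))
    peaked = torch.ones((2, c))
    peaked[:, 0] = 10.0
    if torch.cuda.is_available():
        flat, peaked = flat.cuda(), peaked.cuda()
    zero = KL(flat, c)
    assert torch.allclose(zero, torch.zeros_like(zero), atol=1e-5)
    assert (KL(peaked, c) > 0).all()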

def ce_loss(p, alpha, c, global_step, annealing_step):
    """Adjusted cross-entropy loss on Dirichlet outputs.

    p: one-hot ground-truth labels; alpha: Dirichlet parameters (evidence + 1);
    c: number of classes. The KL term is annealed in over `annealing_step` steps.
    """
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - 1
    label = p
    # Expected cross-entropy under the Dirichlet distribution
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    annealing_coef = min(1, global_step / annealing_step)
    # Remove the evidence of the true class, then regularize the remaining
    # (misleading) evidence towards the uniform Dirichlet.
    alp = E * (1 - label) + 1
    B = annealing_coef * KL(alp, c)
    return torch.mean(A + B)
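

# Illustrative usage of ce_loss (a sketch; the shapes and step values are
# assumptions based on how the loss is applied to the alpha outputs below):
# `p` is a one-hot label matrix and `alpha` the Dirichlet parameters of one
# classification head. The helper name is hypothetical.
def _ce_loss_example(n_classes=4):
    labels = F.one_hot(torch.tensor([1, 3]), n_classes).float()
    alpha = F.softplus(torch.randn(2, n_classes)) + 1
    if torch.cuda.is_available():
        labels, alpha = labels.cuda(), alpha.cuda()
    # Early in training (step 10 of 50) the KL term is only partially weighted.
    return ce_loss(labels, alpha, n_classes, global_step=10, annealing_step=50)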

class TMC(nn.Module):
    def __init__(self, args):
        super(TMC, self).__init__()
        self.args = args
        # Per-modality encoders (defined in models/image.py)
        self.rgbenc = image.ImageEncoder(args)
        self.specenc = image.RawNet(args)
        spec_last_size = args.img_hidden_sz * 1
        rgb_last_size = args.img_hidden_sz * args.num_image_embeds
        # Per-modality classification heads: [Linear, ReLU, Dropout] blocks
        # for each hidden width, then a final Linear to n_classes
        self.spec_depth = nn.ModuleList()
        self.clf_rgb = nn.ModuleList()
        for hidden in args.hidden:
            self.spec_depth.append(nn.Linear(spec_last_size, hidden))
            self.spec_depth.append(nn.ReLU())
            self.spec_depth.append(nn.Dropout(args.dropout))
            spec_last_size = hidden
        self.spec_depth.append(nn.Linear(spec_last_size, args.n_classes))
        for hidden in args.hidden:
            self.clf_rgb.append(nn.Linear(rgb_last_size, hidden))
            self.clf_rgb.append(nn.ReLU())
            self.clf_rgb.append(nn.Dropout(args.dropout))
            rgb_last_size = hidden
        self.clf_rgb.append(nn.Linear(rgb_last_size, args.n_classes))
    def DS_Combin_two(self, alpha1, alpha2):
        """Combine two Dirichlet opinions with Dempster's rule and return the
        Dirichlet parameters of the fused opinion."""
        alpha = dict()
        alpha[0], alpha[1] = alpha1, alpha2
        b, S, E, u = dict(), dict(), dict(), dict()
        for v in range(2):
            S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
            E[v] = alpha[v] - 1
            b[v] = E[v] / (S[v].expand(E[v].shape))  # belief masses
            u[v] = self.args.n_classes / S[v]        # uncertainty mass
        # Outer product of the belief vectors: bb[i, j] = b^0_i * b^1_j
        bb = torch.bmm(b[0].view(-1, self.args.n_classes, 1), b[1].view(-1, 1, self.args.n_classes))
        # b^0 * u^1
        uv1_expand = u[1].expand(b[0].shape)
        bu = torch.mul(b[0], uv1_expand)
        # b^1 * u^0
        uv_expand = u[0].expand(b[0].shape)
        ub = torch.mul(b[1], uv_expand)
        # Conflict K: total mass the two opinions assign to different classes
        # (off-diagonal entries of bb)
        bb_sum = torch.sum(bb, dim=(1, 2))
        bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
        K = bb_sum - bb_diag
        # Fused belief masses b^a
        b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - K).view(-1, 1).expand(b[0].shape))
        # Fused uncertainty u^a
        u_a = torch.mul(u[0], u[1]) / ((1 - K).view(-1, 1).expand(u[0].shape))
        # Map back to Dirichlet parameters: S = C / u, e = b * S, alpha = e + 1
        S_a = self.args.n_classes / u_a
        e_a = torch.mul(b_a, S_a.expand(b_a.shape))
        alpha_a = e_a + 1
        return alpha_a
    def forward(self, rgb, spec):
        spec = self.specenc(spec)
        spec = torch.flatten(spec, start_dim=1)
        rgb = self.rgbenc(rgb)
        rgb = torch.flatten(rgb, start_dim=1)
        spec_out = spec
        for layer in self.spec_depth:
            spec_out = layer(spec_out)
        rgb_out = rgb
        for layer in self.clf_rgb:
            rgb_out = layer(rgb_out)
        # Non-negative evidence via softplus; Dirichlet parameters are evidence + 1
        spec_evidence, rgb_evidence = F.softplus(spec_out), F.softplus(rgb_out)
        spec_alpha, rgb_alpha = spec_evidence + 1, rgb_evidence + 1
        spec_rgb_alpha = self.DS_Combin_two(spec_alpha, rgb_alpha)
        return spec_alpha, rgb_alpha, spec_rgb_alpha
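

# Sanity check for DS_Combin_two (an illustrative addition): a dummy namespace
# stands in for `self`, since the rule only reads `self.args.n_classes`. After
# fusion, belief masses plus uncertainty should still sum to one per sample.
# The helper name is hypothetical.
def _ds_combin_sanity_check(n_classes=4):
    from types import SimpleNamespace
    dummy = SimpleNamespace(args=SimpleNamespace(n_classes=n_classes))
    alpha1 = F.softplus(torch.randn(2, n_classes)) + 1
    alpha2 = F.softplus(torch.randn(2, n_classes)) + 1
    fused = TMC.DS_Combin_two(dummy, alpha1, alpha2)
    S = fused.sum(dim=1, keepdim=True)
    total = ((fused - 1) / S).sum(dim=1, keepdim=True) + n_classes / S
    assert torch.allclose(total, torch.ones_like(total), atol=1e-5)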


class ETMC(TMC):
    def __init__(self, args):
        super(ETMC, self).__init__(args)
        # Pseudo-modality head over the concatenated rgb + spec features.
        # The rgb branch is img_hidden_sz * num_image_embeds wide and the spec
        # branch img_hidden_sz wide (see TMC.__init__), so the fused input
        # width is their sum.
        last_size = args.img_hidden_sz * args.num_image_embeds + args.img_hidden_sz
        self.clf = nn.ModuleList()
        for hidden in args.hidden:
            self.clf.append(nn.Linear(last_size, hidden))
            self.clf.append(nn.ReLU())
            self.clf.append(nn.Dropout(args.dropout))
            last_size = hidden
        self.clf.append(nn.Linear(last_size, args.n_classes))
    def forward(self, rgb, spec):
        spec = self.specenc(spec)
        spec = torch.flatten(spec, start_dim=1)
        rgb = self.rgbenc(rgb)
        rgb = torch.flatten(rgb, start_dim=1)
        spec_out = spec
        for layer in self.spec_depth:
            spec_out = layer(spec_out)
        rgb_out = rgb
        for layer in self.clf_rgb:
            rgb_out = layer(rgb_out)
        # Pseudo modality: classifier over the concatenated per-modality features
        pseudo_out = torch.cat([rgb, spec], -1)
        for layer in self.clf:
            pseudo_out = layer(pseudo_out)
        spec_evidence, rgb_evidence, pseudo_evidence = (
            F.softplus(spec_out), F.softplus(rgb_out), F.softplus(pseudo_out))
        spec_alpha, rgb_alpha, pseudo_alpha = spec_evidence + 1, rgb_evidence + 1, pseudo_evidence + 1
        # Fuse all three opinions pairwise with Dempster's rule
        fused_alpha = self.DS_Combin_two(self.DS_Combin_two(spec_alpha, rgb_alpha), pseudo_alpha)
        return spec_alpha, rgb_alpha, pseudo_alpha, fused_alpha
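

# --- End-to-end usage sketch -------------------------------------------------
# Commented out because it depends on the encoders in models/image.py and on
# the training script's argument parser; the field names and values below are
# assumptions for illustration only.
#
#   args = argparse.Namespace(img_hidden_sz=..., num_image_embeds=...,
#                             hidden=[512], dropout=0.1, n_classes=10, ...)
#   model = ETMC(args)
#   spec_alpha, rgb_alpha, pseudo_alpha, fused_alpha = model(rgb_batch, spec_batch)
#   loss = sum(ce_loss(one_hot_targets, a, args.n_classes, step, annealing_step)
#              for a in (spec_alpha, rgb_alpha, pseudo_alpha, fused_alpha))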