import torch
import torch.nn as nn
import torch.nn.functional as F

from models import image


def KL(alpha, c):
    """KL divergence between Dir(alpha) and the uniform Dirichlet Dir(1)."""
    # Create beta on the same device as alpha (the original cuda-availability
    # check breaks when alpha lives on the CPU while a GPU is present).
    beta = torch.ones((1, c), device=alpha.device)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)
    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
    return kl


def ce_loss(p, alpha, c, global_step, annealing_step):
    """Evidential cross-entropy loss with an annealed KL regularizer.

    p: one-hot labels; alpha: Dirichlet parameters (evidence + 1); c: number of classes.
    """
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - 1
    label = p
    # Expected cross-entropy under Dir(alpha).
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    # Ramp the KL weight from 0 to 1 over `annealing_step` steps.
    annealing_coef = min(1.0, global_step / annealing_step)
    # Remove the true-class evidence before regularizing toward Dir(1).
    alp = E * (1 - label) + 1
    B = annealing_coef * KL(alp, c)
    return torch.mean(A + B)


class TMC(nn.Module):
    def __init__(self, args):
        super(TMC, self).__init__()
        self.args = args
        self.rgbenc = image.ImageEncoder(args)
        self.specenc = image.RawNet(args)
        spec_last_size = args.img_hidden_sz * 1
        rgb_last_size = args.img_hidden_sz * args.num_image_embeds
        self.spec_depth = nn.ModuleList()
        self.clf_rgb = nn.ModuleList()
        for hidden in args.hidden:
            self.spec_depth.append(nn.Linear(spec_last_size, hidden))
            self.spec_depth.append(nn.ReLU())
            self.spec_depth.append(nn.Dropout(args.dropout))
            spec_last_size = hidden
        self.spec_depth.append(nn.Linear(spec_last_size, args.n_classes))

        for hidden in args.hidden:
            self.clf_rgb.append(nn.Linear(rgb_last_size, hidden))
            self.clf_rgb.append(nn.ReLU())
            self.clf_rgb.append(nn.Dropout(args.dropout))
            rgb_last_size = hidden
        self.clf_rgb.append(nn.Linear(rgb_last_size, args.n_classes))

    def DS_Combin_two(self, alpha1, alpha2):
        """Combine two Dirichlet evidences with Dempster's rule of combination."""
        alpha = dict()
        alpha[0], alpha[1] = alpha1, alpha2
        b, S, E, u = dict(), dict(), dict(), dict()
        for v in range(2):
            S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
            E[v] = alpha[v] - 1
            b[v] = E[v] / (S[v].expand(E[v].shape))  # belief masses
            u[v] = self.args.n_classes / S[v]        # uncertainty mass
        # b^0 (x) b^1: outer product of the two belief vectors
        bb = torch.bmm(b[0].view(-1, self.args.n_classes, 1), b[1].view(-1, 1, self.args.n_classes))
        # b^0 * u^1
        uv1_expand = u[1].expand(b[0].shape)
        bu = torch.mul(b[0], uv1_expand)
        # b^1 * u^0
        uv_expand = u[0].expand(b[0].shape)
        ub = torch.mul(b[1], uv_expand)
        # conflict K: mass assigned to incompatible class pairs (off-diagonal of bb)
        bb_sum = torch.sum(bb, dim=(1, 2))
        bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
        K = bb_sum - bb_diag
        # fused belief b^a and uncertainty u^a, renormalized by 1 - K;
        # per sample, b_a.sum(dim=1) + u_a == 1
        b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - K).view(-1, 1).expand(b[0].shape))
        u_a = torch.mul(u[0], u[1]) / ((1 - K).view(-1, 1).expand(u[0].shape))
        # map back to Dirichlet parameters
        S_a = self.args.n_classes / u_a
        e_a = torch.mul(b_a, S_a.expand(b_a.shape))
        alpha_a = e_a + 1
        return alpha_a

    def forward(self, rgb, spec):
        spec = self.specenc(spec)
        spec = torch.flatten(spec, start_dim=1)
        rgb = self.rgbenc(rgb)
        rgb = torch.flatten(rgb, start_dim=1)
        spec_out = spec
        for layer in self.spec_depth:
            spec_out = layer(spec_out)
        rgb_out = rgb
        for layer in self.clf_rgb:
            rgb_out = layer(rgb_out)
        # softplus keeps the evidence non-negative; alpha = evidence + 1
        spec_evidence, rgb_evidence = F.softplus(spec_out), F.softplus(rgb_out)
        spec_alpha, rgb_alpha = spec_evidence + 1, rgb_evidence + 1
        spec_rgb_alpha = self.DS_Combin_two(spec_alpha, rgb_alpha)
        return spec_alpha, rgb_alpha, spec_rgb_alpha
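
# --- Illustrative sketch, not part of the original file ----------------------
# DS_Combin_two only reads self.args.n_classes, so it can be exercised without
# building the full model by passing a stub object as `self`; the stub and the
# toy alpha values below are hypothetical. Combining confident evidence with
# vacuous evidence (alpha = 1 everywhere, i.e. zero evidence) should return the
# confident alpha unchanged, since the vacuous view carries full uncertainty.
def _demo_ds_combin_two():
    from types import SimpleNamespace

    stub = SimpleNamespace(args=SimpleNamespace(n_classes=3))
    confident = torch.tensor([[9.0, 1.0, 1.0]])  # strong evidence for class 0
    vacuous = torch.tensor([[1.0, 1.0, 1.0]])    # no evidence at all
    fused = TMC.DS_Combin_two(stub, confident, vacuous)
    print(fused)  # [[9., 1., 1.]]: the vacuous view leaves the fusion unchanged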
class ETMC(TMC):
    def __init__(self, args):
        super(ETMC, self).__init__(args)
        # Input size of the pseudo (fused-feature) branch: the concatenation of
        # RGB features (img_hidden_sz * num_image_embeds) and spectrogram
        # features (img_hidden_sz), matching torch.cat in forward. The original
        # summed the RGB size twice, which only works when num_image_embeds == 1.
        last_size = args.img_hidden_sz * args.num_image_embeds + args.img_hidden_sz
        self.clf = nn.ModuleList()
        for hidden in args.hidden:
            self.clf.append(nn.Linear(last_size, hidden))
            self.clf.append(nn.ReLU())
            self.clf.append(nn.Dropout(args.dropout))
            last_size = hidden
        self.clf.append(nn.Linear(last_size, args.n_classes))

    def forward(self, rgb, spec):
        spec = self.specenc(spec)
        spec = torch.flatten(spec, start_dim=1)
        rgb = self.rgbenc(rgb)
        rgb = torch.flatten(rgb, start_dim=1)
        spec_out = spec
        for layer in self.spec_depth:
            spec_out = layer(spec_out)
        rgb_out = rgb
        for layer in self.clf_rgb:
            rgb_out = layer(rgb_out)
        # Pseudo view: a classifier on the concatenated features of both modalities.
        pseudo_out = torch.cat([rgb, spec], -1)
        for layer in self.clf:
            pseudo_out = layer(pseudo_out)
        spec_evidence, rgb_evidence, pseudo_evidence = (
            F.softplus(spec_out), F.softplus(rgb_out), F.softplus(pseudo_out)
        )
        spec_alpha, rgb_alpha, pseudo_alpha = (
            spec_evidence + 1, rgb_evidence + 1, pseudo_evidence + 1
        )
        # Fuse the three views pairwise with Dempster's rule.
        fused_alpha = self.DS_Combin_two(self.DS_Combin_two(spec_alpha, rgb_alpha), pseudo_alpha)
        return spec_alpha, rgb_alpha, pseudo_alpha, fused_alpha
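
# --- Minimal usage sketch, not part of the original file ---------------------
# Exercises the evidential loss on random evidence. The class count, batch
# size, and annealing schedule are illustrative values, and the demo avoids
# the encoders in models.image so it stays self-contained.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_classes, batch = 4, 8
    evidence = F.softplus(torch.randn(batch, n_classes))
    alpha = evidence + 1
    labels = F.one_hot(torch.randint(0, n_classes, (batch,)), n_classes).float()
    for step in (0, 5, 10):
        # The KL regularizer ramps up as global_step approaches annealing_step.
        loss = ce_loss(labels, alpha, n_classes, global_step=step, annealing_step=10)
        print(f"step={step:2d}  loss={loss.item():.4f}")
    _demo_ds_combin_two()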