# NOTE(review): the lines below were web-scrape residue (a file-listing page:
# "Spaces:", "No application file", file size, commit hash, line-number gutter),
# not Python source. Preserved as a comment so the file parses.
# Spaces: | No application file | File size: 3,418 Bytes | 15fa80a
from CLIP.clip import clip
from CLIP.clip import model
import torch
def topk_overlap_loss(gt, pred, K=2, metric='l1'):
    """Compare the top-K entries of a ground-truth and a predicted score vector.

    Two top-K sets are considered: the indices of the K largest entries of
    ``gt`` and those of ``pred``.  Both tensors are then gathered at both
    index sets and compared with the chosen metric.

    Args:
        gt: ground-truth scores. Assumes a 1-D tensor — the ``[:K]`` slice
            below operates on the first dimension (TODO confirm; use
            ``topk_overlap_loss_batch`` for batched 2-D input).
        pred: predicted scores, same shape as ``gt``.
        K: number of top entries to compare.
        metric: one of ``'l1'``, ``'l2'``, ``'kl-full'``, ``'jsd-full'``,
            ``'kl-topk'``, ``'jsd-topk'``.

    Returns:
        The loss. ``'l1'`` returns an element-wise tensor of length K;
        the other metrics reduce to a scalar tensor.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Indices of the K largest ground-truth entries ...
    idx = torch.argsort(gt, descending=True)[:K]
    pred_TopK_1 = pred.gather(-1, idx)
    gt_Topk_1 = gt.gather(-1, idx)
    # ... and of the K largest predicted entries.
    idx_pred = torch.argsort(pred, descending=True)[:K]
    # Let a shape mismatch raise immediately instead of swallowing the
    # exception and failing later with an unrelated NameError.
    gt_TopK_2 = gt.gather(-1, idx_pred)
    pred_TopK_2 = pred.gather(-1, idx_pred)

    # Renormalize each gathered top-K slice so the KL/JSD metrics receive
    # valid probability distributions.
    gt_Topk_1_normed = torch.nn.functional.softmax(gt_Topk_1, dim=-1)
    pred_TopK_1_normed = torch.nn.functional.softmax(pred_TopK_1, dim=-1)
    gt_TopK_2_normed = torch.nn.functional.softmax(gt_TopK_2, dim=-1)
    pred_TopK_2_normed = torch.nn.functional.softmax(pred_TopK_2, dim=-1)

    def kl(a, b):
        # kl_div(input_log, target) computes KL(target || input); ``a`` must
        # be strictly positive (it is log-ed).
        return torch.nn.functional.kl_div(a.log(), b, reduction="batchmean")

    def jsd(a, b):
        # Symmetrized KL (Jensen-Shannon-style average of both directions).
        return (kl(a, b) + kl(b, a)) / 2

    if metric == 'l1':
        loss = torch.abs(pred_TopK_1 - gt_Topk_1) + torch.abs(gt_TopK_2 - pred_TopK_2)
        loss = loss / (2 * K)
    elif metric == "l2":
        loss = torch.norm(pred_TopK_1 - gt_Topk_1, p=2) + torch.norm(gt_TopK_2 - pred_TopK_2, p=2)
        loss = loss / (2 * K)
    elif metric == "kl-full":
        loss = kl(gt, pred)
    elif metric == "jsd-full":
        loss = jsd(gt, pred)
    elif metric == "kl-topk":
        loss = kl(gt_Topk_1_normed, pred_TopK_1_normed) + kl(gt_TopK_2_normed, pred_TopK_2_normed)
        loss /= 2
    elif metric == "jsd-topk":
        loss = jsd(gt_Topk_1_normed, pred_TopK_1_normed) + jsd(gt_TopK_2_normed, pred_TopK_2_normed)
        loss /= 2
    else:
        # Previously an unknown metric crashed with NameError on ``loss``.
        raise ValueError(f"unknown metric: {metric!r}")
    return loss
def topk_overlap_loss_batch(gt, pred, K=2, metric='l1'):
    """Batched variant of ``topk_overlap_loss`` for 2-D (batch, n) tensors.

    For each row, gathers both tensors at the indices of the K largest
    ground-truth entries and at the indices of the K largest predicted
    entries, then compares the gathered values with the chosen metric.

    Args:
        gt: ground-truth scores, shape (batch, n) — sorting/gathering is
            along ``dim=1``.
        pred: predicted scores, same shape as ``gt``.
        K: number of top entries per row to compare.
        metric: one of ``'l1'``, ``'l2'``, ``'kl-full'``, ``'jsd-full'``,
            ``'kl-topk'``, ``'jsd-topk'``.

    Returns:
        The loss. ``'l1'`` returns a (batch, K) tensor; the other metrics
        reduce to a scalar tensor.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Per-row indices of the K largest ground-truth entries ...
    idx = torch.argsort(gt, dim=1, descending=True)[:, :K]
    pred_TopK_1 = pred.gather(1, idx)
    gt_Topk_1 = gt.gather(1, idx)
    # ... and of the K largest predicted entries.
    idx_pred = torch.argsort(pred, dim=1, descending=True)[:, :K]
    # Let a shape mismatch raise immediately instead of swallowing the
    # exception and failing later with an unrelated NameError.
    gt_TopK_2 = gt.gather(1, idx_pred)
    pred_TopK_2 = pred.gather(1, idx_pred)

    # Renormalize each gathered top-K slice so the KL/JSD metrics receive
    # valid probability distributions (softmax over the K gathered values).
    gt_Topk_1_normed = torch.nn.functional.softmax(gt_Topk_1, dim=-1)
    pred_TopK_1_normed = torch.nn.functional.softmax(pred_TopK_1, dim=-1)
    gt_TopK_2_normed = torch.nn.functional.softmax(gt_TopK_2, dim=-1)
    pred_TopK_2_normed = torch.nn.functional.softmax(pred_TopK_2, dim=-1)

    def kl(a, b):
        # kl_div(input_log, target) computes KL(target || input); ``a`` must
        # be strictly positive (it is log-ed).
        return torch.nn.functional.kl_div(a.log(), b, reduction="batchmean")

    def jsd(a, b):
        # Symmetrized KL (Jensen-Shannon-style average of both directions).
        return (kl(a, b) + kl(b, a)) / 2

    if metric == 'l1':
        loss = torch.abs(pred_TopK_1 - gt_Topk_1) + torch.abs(gt_TopK_2 - pred_TopK_2)
        loss = loss / (2 * K)
    elif metric == "l2":
        loss = torch.norm(pred_TopK_1 - gt_Topk_1, p=2) + torch.norm(gt_TopK_2 - pred_TopK_2, p=2)
        loss = loss / (2 * K)
    elif metric == "kl-full":
        loss = kl(gt, pred)
    elif metric == "jsd-full":
        loss = jsd(gt, pred)
    elif metric == "kl-topk":
        loss = kl(gt_Topk_1_normed, pred_TopK_1_normed) + kl(gt_TopK_2_normed, pred_TopK_2_normed)
        loss /= 2
    elif metric == "jsd-topk":
        loss = jsd(gt_Topk_1_normed, pred_TopK_1_normed) + jsd(gt_TopK_2_normed, pred_TopK_2_normed)
        loss /= 2
    else:
        # Previously an unknown metric crashed with NameError on ``loss``.
        raise ValueError(f"unknown metric: {metric!r}")
    return loss