File size: 3,418 Bytes
15fa80a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from CLIP.clip import clip
from CLIP.clip import model
import torch

def topk_overlap_loss(gt, pred, K=2, metric='l1'):
    """Compare ``pred`` against ``gt`` on the top-``K`` entries of each tensor.

    Two index sets are taken along the last dimension: the positions of the
    ``K`` largest entries of ``gt`` and those of the ``K`` largest entries of
    ``pred``. Both tensors are gathered at each index set and the gathered
    values are compared with ``metric``. Inputs with leading dimensions are
    ranked independently along the last dimension.

    Args:
        gt: Reference scores. Presumably the same shape as ``pred`` -- each
            tensor is gathered at the other's top-K indices.
        pred: Predicted scores.
        K: Number of top entries taken from each ranking; must be >= 1.
        metric: One of ``'l1'``, ``'l2'``, ``'kl-full'``, ``'jsd-full'``,
            ``'kl-topk'``, ``'jsd-topk'``:

            * ``'l1'`` -- ``(|pred - gt| at gt's top-K + |gt - pred| at
              pred's top-K) / (2*K)``, elementwise. NOTE: the result is not
              reduced; it has the shape of the gathered slice.
            * ``'l2'`` -- sum of the L2 norms of the two difference vectors,
              divided by ``2*K``.
            * ``'kl-full'`` / ``'jsd-full'`` -- divergence over the whole
              tensors, which are used as probabilities as-is (``log`` is
              taken, so presumably they must be positive).
            * ``'kl-topk'`` / ``'jsd-topk'`` -- divergence between the
              softmax-normalised top-K slices, averaged over both index sets.

    Returns:
        A tensor holding the loss (see ``metric`` for its shape).

    Raises:
        ValueError: if ``metric`` is unknown or ``K < 1``.
        RuntimeError: if ``gt`` and ``pred`` cannot be gathered at each
            other's top-K indices (e.g. mismatched shapes).
    """
    valid_metrics = ("l1", "l2", "kl-full", "jsd-full", "kl-topk", "jsd-topk")
    if metric not in valid_metrics:
        # Fail fast instead of reaching ``return loss`` with nothing assigned.
        raise ValueError(f"unknown metric {metric!r}; expected one of {valid_metrics}")
    if K < 1:
        # A non-positive K would silently select the wrong entries
        # (e.g. ``[..., :-1]`` drops only the last one).
        raise ValueError(f"K must be >= 1, got {K}")

    # Positions of the K largest entries of each tensor along the last dim.
    idx_gt = torch.argsort(gt, descending=True)[..., :K]
    idx_pred = torch.argsort(pred, descending=True)[..., :K]

    try:
        pred_at_gt = pred.gather(-1, idx_gt)
        gt_at_gt = gt.gather(-1, idx_gt)
        gt_at_pred = gt.gather(-1, idx_pred)
        pred_at_pred = pred.gather(-1, idx_pred)
    except (RuntimeError, IndexError) as err:
        # Report the shapes rather than continuing with undefined values.
        raise RuntimeError(
            f"topk_overlap_loss: cannot gather top-{K} entries; "
            f"gt shape {tuple(gt.shape)}, pred shape {tuple(pred.shape)}"
        ) from err

    def kl(a, b):
        # kl_div(input=log a, target=b) sums b * (log b - log a), i.e. it is
        # KL(b || a), then divides by input.size(0) ("batchmean").
        return torch.nn.functional.kl_div(a.log(), b, reduction="batchmean")

    def jsd(a, b):
        # NOTE: average of the two directed KL divergences (symmetrised KL),
        # not the mixture-based Jensen-Shannon divergence.
        return (kl(a, b) + kl(b, a)) / 2

    if metric == "l1":
        loss = torch.abs(pred_at_gt - gt_at_gt) + torch.abs(gt_at_pred - pred_at_pred)
        return loss / (2 * K)
    if metric == "l2":
        loss = (torch.norm(pred_at_gt - gt_at_gt, p=2)
                + torch.norm(gt_at_pred - pred_at_pred, p=2))
        return loss / (2 * K)
    if metric == "kl-full":
        return kl(gt, pred)
    if metric == "jsd-full":
        return jsd(gt, pred)

    # Remaining metrics compare softmax-normalised top-K slices.
    softmax = torch.nn.functional.softmax
    gt_at_gt_n = softmax(gt_at_gt, dim=-1)
    pred_at_gt_n = softmax(pred_at_gt, dim=-1)
    gt_at_pred_n = softmax(gt_at_pred, dim=-1)
    pred_at_pred_n = softmax(pred_at_pred, dim=-1)
    divergence = kl if metric == "kl-topk" else jsd
    return (divergence(gt_at_gt_n, pred_at_gt_n)
            + divergence(gt_at_pred_n, pred_at_pred_n)) / 2

def topk_overlap_loss_batch(gt, pred, K=2, metric='l1'):
    """Batched top-``K`` overlap loss, ranking each row along dimension 1.

    For every row, two index sets are taken: the positions of the ``K``
    largest entries of ``gt`` and those of the ``K`` largest entries of
    ``pred``. Both tensors are gathered at each set and compared with
    ``metric``. The code sorts along dim 1 and slices ``[:, :K]``, so inputs
    must be at least 2-D; presumably ``(batch, N)`` -- TODO confirm.

    Args:
        gt: Reference scores, presumably the same shape as ``pred``.
        pred: Predicted scores.
        K: Number of top entries taken from each row's ranking; must be >= 1.
        metric: One of ``'l1'``, ``'l2'``, ``'kl-full'``, ``'jsd-full'``,
            ``'kl-topk'``, ``'jsd-topk'``:

            * ``'l1'`` -- ``(|pred - gt| at gt's top-K + |gt - pred| at
              pred's top-K) / (2*K)``, elementwise. NOTE: not reduced; for
              2-D input the result has shape ``(batch, K)``.
            * ``'l2'`` -- sum of the L2 norms of the two difference tensors
              taken over the whole batch, divided by ``2*K`` (one scalar).
            * ``'kl-full'`` / ``'jsd-full'`` -- divergence over the whole
              tensors, used as probabilities as-is (``log`` is taken, so
              presumably they must be positive).
            * ``'kl-topk'`` / ``'jsd-topk'`` -- divergence between the
              softmax-normalised top-K slices, averaged over both index sets.

    Returns:
        A tensor holding the loss (see ``metric`` for its shape).

    Raises:
        ValueError: if ``metric`` is unknown or ``K < 1``.
        RuntimeError: if ``gt`` and ``pred`` cannot be gathered at each
            other's top-K indices (e.g. mismatched shapes).
    """
    valid_metrics = ("l1", "l2", "kl-full", "jsd-full", "kl-topk", "jsd-topk")
    if metric not in valid_metrics:
        # Fail fast instead of reaching ``return loss`` with nothing assigned.
        raise ValueError(f"unknown metric {metric!r}; expected one of {valid_metrics}")
    if K < 1:
        # A non-positive K would silently select the wrong columns
        # (e.g. ``[:, :-1]`` drops only the last one).
        raise ValueError(f"K must be >= 1, got {K}")

    # Per-row positions of the K largest entries of each tensor.
    idx_gt = torch.argsort(gt, dim=1, descending=True)[:, :K]
    idx_pred = torch.argsort(pred, dim=1, descending=True)[:, :K]

    try:
        pred_at_gt = pred.gather(1, idx_gt)
        gt_at_gt = gt.gather(1, idx_gt)
        gt_at_pred = gt.gather(1, idx_pred)
        pred_at_pred = pred.gather(1, idx_pred)
    except (RuntimeError, IndexError) as err:
        # Report the shapes rather than continuing with undefined values.
        raise RuntimeError(
            f"topk_overlap_loss_batch: cannot gather top-{K} entries; "
            f"gt shape {tuple(gt.shape)}, pred shape {tuple(pred.shape)}"
        ) from err

    def kl(a, b):
        # kl_div(input=log a, target=b) sums b * (log b - log a), i.e. it is
        # KL(b || a), then divides by input.size(0) ("batchmean").
        return torch.nn.functional.kl_div(a.log(), b, reduction="batchmean")

    def jsd(a, b):
        # NOTE: average of the two directed KL divergences (symmetrised KL),
        # not the mixture-based Jensen-Shannon divergence.
        return (kl(a, b) + kl(b, a)) / 2

    if metric == "l1":
        loss = torch.abs(pred_at_gt - gt_at_gt) + torch.abs(gt_at_pred - pred_at_pred)
        return loss / (2 * K)
    if metric == "l2":
        loss = (torch.norm(pred_at_gt - gt_at_gt, p=2)
                + torch.norm(gt_at_pred - pred_at_pred, p=2))
        return loss / (2 * K)
    if metric == "kl-full":
        return kl(gt, pred)
    if metric == "jsd-full":
        return jsd(gt, pred)

    # Remaining metrics compare softmax-normalised top-K slices.
    softmax = torch.nn.functional.softmax
    gt_at_gt_n = softmax(gt_at_gt, dim=-1)
    pred_at_gt_n = softmax(pred_at_gt, dim=-1)
    gt_at_pred_n = softmax(gt_at_pred, dim=-1)
    pred_at_pred_n = softmax(pred_at_pred, dim=-1)
    divergence = kl if metric == "kl-topk" else jsd
    return (divergence(gt_at_gt_n, pred_at_gt_n)
            + divergence(gt_at_pred_n, pred_at_pred_n)) / 2