jiaxianustc commited on
Commit
e749e85
·
1 Parent(s): 2baa7f7
UltraFlow/losses/losses.py CHANGED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from UltraFlow import layers
5
+
6
+ # margin ranking loss
7
+ class pair_wise_ranking_loss(nn.Module):
8
+ def __init__(self, config):
9
+ super(pair_wise_ranking_loss, self).__init__()
10
+ self.config = config
11
+ self.threshold_filter = nn.Threshold(0.2, 0)
12
+ self.score_predict = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, 1)
13
+
14
+ def ranking_loss(self, z_A, z_B, relation):
15
+ """
16
+ loss for a given set of pixels:
17
+ z_A: predicted absolute depth for pixels A
18
+ z_B: predicted absolute depth for pixels B
19
+ relation: -1, 0, 1
20
+ """
21
+ pred_depth = z_A - z_B
22
+ log_loss = torch.mean(torch.log(1 + torch.exp(-relation[relation != 0] * pred_depth[relation != 0])))
23
+ return log_loss
24
+
25
+ @torch.no_grad()
26
+ def get_rank_relation(self, y_A, y_B):
27
+ pred_depth = y_A - y_B
28
+ pred_depth[self.threshold_filter(pred_depth.abs()) == 0] = 0
29
+
30
+ return pred_depth.sign()
31
+
32
+ def forward(self, output_embedding, target):
33
+ batch_repeat_num = len(output_embedding)
34
+ batch_size = batch_repeat_num // 2
35
+
36
+ score_predict = self.score_predict(output_embedding)
37
+ x_A, y_A, x_B, y_B = score_predict[:batch_size], target[:batch_size], score_predict[batch_size:], target[batch_size:]
38
+
39
+ relation = self.get_rank_relation(y_A, y_B)
40
+
41
+ ranking_loss = self.ranking_loss(x_A, x_B, relation)
42
+
43
+ relation_pred = self.get_rank_relation(x_A, x_B)
44
+
45
+ return ranking_loss, relation.squeeze(), relation_pred.squeeze()
46
+
47
+ # binary cross entropy loss
48
+ class pair_wise_ranking_loss_v2(nn.Module):
49
+ def __init__(self, config):
50
+ super(pair_wise_ranking_loss_v2, self).__init__()
51
+ self.config = config
52
+ self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
53
+ self.loss_fn = nn.CrossEntropyLoss()
54
+ self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
55
+ self.m = nn.Softmax(dim=1)
56
+
57
+ @torch.no_grad()
58
+ def get_rank_relation(self, y_A, y_B):
59
+ # y_A: [batch, 1]
60
+ # target_relation: 0: <=, 1: >
61
+ target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
62
+ target_relation[(y_A - y_B) > 0.0] = 1
63
+
64
+ return target_relation.squeeze()
65
+
66
+ def forward(self, output_embedding, target, assay_des):
67
+ batch_repeat_num = len(output_embedding)
68
+ batch_size = batch_repeat_num // 2
69
+ x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\
70
+ output_embedding[batch_size:], target[batch_size:]
71
+
72
+ relation = self.get_rank_relation(y_A, y_B)
73
+
74
+ if self.pretrain_use_assay_description:
75
+ assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size: ]
76
+ agg_A = x_A + assay_A
77
+ agg_B = x_B + assay_B
78
+ relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
79
+ else:
80
+ relation_pred = self.relation_mlp(torch.cat([x_A,x_B], dim=1))
81
+
82
+ ranking_loss = self.loss_fn(relation_pred, relation)
83
+
84
+ _, y_pred = self.m(relation_pred).max(dim=1)
85
+
86
+ return ranking_loss, relation.squeeze(), y_pred
87
+
88
+ # binary cross entropy loss
89
+ class pairwise_BCE_loss(nn.Module):
90
+ def __init__(self, config):
91
+ super(pairwise_BCE_loss, self).__init__()
92
+ self.config = config
93
+ self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
94
+ self.loss_fn = nn.CrossEntropyLoss(reduce=False)
95
+ if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
96
+ self.relation_mlp = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1) * 2, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
97
+ else:
98
+ self.relation_mlp = layers.FC(config.model.inter_out_dim * 4, [config.model.inter_out_dim * 2, config.model.inter_out_dim], config.model.dropout, 2)
99
+ self.m = nn.Softmax(dim=1)
100
+
101
+ @torch.no_grad()
102
+ def get_rank_relation(self, y_A, y_B):
103
+ # y_A: [batch, 1]
104
+ # target_relation: 0: <=, 1: >
105
+ target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
106
+ target_relation[(y_A - y_B) > 0.0] = 1
107
+
108
+ return target_relation.squeeze()
109
+
110
+ def forward(self, output_embedding, target, assay_des):
111
+ batch_repeat_num = len(output_embedding)
112
+ batch_size = batch_repeat_num // 2
113
+ x_A, y_A, x_B, y_B = output_embedding[:batch_size], target[:batch_size],\
114
+ output_embedding[batch_size:], target[batch_size:]
115
+
116
+ relation = self.get_rank_relation(y_A, y_B)
117
+
118
+ if self.pretrain_use_assay_description:
119
+ assay_A, assay_B = assay_des[:batch_size], assay_des[batch_size: ]
120
+ agg_A = x_A + assay_A
121
+ agg_B = x_B + assay_B
122
+ relation_pred = self.relation_mlp(torch.cat([agg_A, agg_B], dim=1))
123
+ else:
124
+ relation_pred = self.relation_mlp(torch.cat([x_A,x_B], dim=1))
125
+
126
+ ranking_loss = self.loss_fn(relation_pred, relation)
127
+
128
+ _, y_pred = self.m(relation_pred).max(dim=1)
129
+
130
+ return ranking_loss, relation.squeeze(), y_pred
workdir/gradio/checkpointbest_valid_1.ckp CHANGED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6a16d1c3bf867f08d55242deb52272b481da4bacbd6867f5141bc3325a325203
3
- size 11621767