4kasha
commited on
Commit
·
f31ab4f
1
Parent(s):
37d364a
fix
Browse files- aligner.py +4 -4
- app.py +4 -6
- otfuncs.py +3 -3
aligner.py
CHANGED
@@ -30,19 +30,19 @@ class Aligner:
|
|
30 |
self.weight_func = compute_weights_norm
|
31 |
|
32 |
def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
|
33 |
-
P, Cost, log, similarity_matrix
|
34 |
print(log.keys())
|
35 |
if torch.is_tensor(P):
|
36 |
P = P.to('cpu').numpy()
|
37 |
loss = log.get('cost', 'NotImplemented')
|
38 |
|
39 |
-
return P, Cost, loss, similarity_matrix
|
40 |
|
41 |
def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
|
42 |
s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
|
43 |
s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
|
44 |
|
45 |
-
C, similarity_matrix
|
46 |
s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
|
47 |
|
48 |
if self.ot_type == 'ot':
|
@@ -96,7 +96,7 @@ class Aligner:
|
|
96 |
elif self.ot_type == 'none':
|
97 |
P = 1 - C
|
98 |
|
99 |
-
return P, C, log, similarity_matrix
|
100 |
|
101 |
def convert_to_numpy(self, s1_weights, s2_weights, C):
|
102 |
if torch.is_tensor(s1_weights):
|
|
|
30 |
self.weight_func = compute_weights_norm
|
31 |
|
32 |
def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
|
33 |
+
P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
|
34 |
print(log.keys())
|
35 |
if torch.is_tensor(P):
|
36 |
P = P.to('cpu').numpy()
|
37 |
loss = log.get('cost', 'NotImplemented')
|
38 |
|
39 |
+
return P, Cost, loss, similarity_matrix
|
40 |
|
41 |
def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
|
42 |
s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
|
43 |
s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
|
44 |
|
45 |
+
C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
|
46 |
s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
|
47 |
|
48 |
if self.ot_type == 'ot':
|
|
|
96 |
elif self.ot_type == 'none':
|
97 |
P = 1 - C
|
98 |
|
99 |
+
return P, C, log, similarity_matrix
|
100 |
|
101 |
def convert_to_numpy(self, s1_weights, s2_weights, C):
|
102 |
if torch.is_tensor(s1_weights):
|
app.py
CHANGED
@@ -8,8 +8,6 @@ from nltk.tokenize import word_tokenize
|
|
8 |
from transformers import AutoModel, AutoTokenizer
|
9 |
|
10 |
from aligner import Aligner
|
11 |
-
|
12 |
-
# from utils import align_matrix_heatmap, plot_align_matrix_heatmap
|
13 |
from plotools import (
|
14 |
plot_align_matrix_heatmap_plotly,
|
15 |
plot_similarity_matrix_heatmap_plotly,
|
@@ -45,8 +43,8 @@ def init_aligner(
|
|
45 |
dist_type="cos",
|
46 |
weight_type="uniform",
|
47 |
distortion=distortion,
|
48 |
-
thresh=threshhold,
|
49 |
-
tau=tau,
|
50 |
div_type="--",
|
51 |
)
|
52 |
|
@@ -86,14 +84,14 @@ def main():
|
|
86 |
1.0,
|
87 |
value=0.98,
|
88 |
help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
|
89 |
-
)
|
90 |
threshhold = st.sidebar.slider(
|
91 |
"threshhold: $\lambda$",
|
92 |
0.0,
|
93 |
1.0,
|
94 |
value=0.22,
|
95 |
help="sparsity of alignment matrix",
|
96 |
-
)
|
97 |
show_assignments = st.sidebar.checkbox("show assignments", value=True)
|
98 |
if show_assignments:
|
99 |
n_neighbors = st.sidebar.slider(
|
|
|
8 |
from transformers import AutoModel, AutoTokenizer
|
9 |
|
10 |
from aligner import Aligner
|
|
|
|
|
11 |
from plotools import (
|
12 |
plot_align_matrix_heatmap_plotly,
|
13 |
plot_similarity_matrix_heatmap_plotly,
|
|
|
43 |
dist_type="cos",
|
44 |
weight_type="uniform",
|
45 |
distortion=distortion,
|
46 |
+
thresh=threshhold,
|
47 |
+
tau=tau,
|
48 |
div_type="--",
|
49 |
)
|
50 |
|
|
|
84 |
1.0,
|
85 |
value=0.98,
|
86 |
help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
|
87 |
+
)
|
88 |
threshhold = st.sidebar.slider(
|
89 |
"threshhold: $\lambda$",
|
90 |
0.0,
|
91 |
1.0,
|
92 |
value=0.22,
|
93 |
help="sparsity of alignment matrix",
|
94 |
+
)
|
95 |
show_assignments = st.sidebar.checkbox("show assignments", value=True)
|
96 |
if show_assignments:
|
97 |
n_neighbors = st.sidebar.slider(
|
otfuncs.py
CHANGED
@@ -12,11 +12,11 @@ def compute_distance_matrix_cosine(
|
|
12 |
torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
|
13 |
+ 1.0
|
14 |
) / 2 # Range 0-1
|
15 |
-
C
|
16 |
C = min_max_scaling(C) # Range 0-1
|
17 |
C = 1.0 - C # Convert to distance
|
18 |
|
19 |
-
return C, sim_matrix
|
20 |
|
21 |
|
22 |
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
|
@@ -48,7 +48,7 @@ def apply_distortion(sim_matrix, ratio):
|
|
48 |
|
49 |
sim_matrix = torch.mul(sim_matrix, distortion_mask)
|
50 |
|
51 |
-
return sim_matrix
|
52 |
|
53 |
|
54 |
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
|
|
|
12 |
torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
|
13 |
+ 1.0
|
14 |
) / 2 # Range 0-1
|
15 |
+
C = apply_distortion(sim_matrix, distortion_ratio)
|
16 |
C = min_max_scaling(C) # Range 0-1
|
17 |
C = 1.0 - C # Convert to distance
|
18 |
|
19 |
+
return C, sim_matrix
|
20 |
|
21 |
|
22 |
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
|
|
|
48 |
|
49 |
sim_matrix = torch.mul(sim_matrix, distortion_mask)
|
50 |
|
51 |
+
return sim_matrix
|
52 |
|
53 |
|
54 |
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
|