4kasha
commited on
Commit
Ā·
37d364a
1
Parent(s):
94f5fd3
update
Browse files- aligner.py +8 -19
- app.py +85 -42
- otfuncs.py +28 -14
- plotools.py +137 -75
- requirements.txt +2 -1
aligner.py
CHANGED
@@ -10,10 +10,9 @@ from otfuncs import (
|
|
10 |
)
|
11 |
|
12 |
class Aligner:
|
13 |
-
def __init__(self, ot_type, sinkhorn,
|
14 |
self.ot_type = ot_type
|
15 |
self.sinkhorn = sinkhorn
|
16 |
-
self.chimera = chimera
|
17 |
self.dist_type = dist_type
|
18 |
self.weight_type = weight_type
|
19 |
self.distotion = distortion
|
@@ -31,20 +30,19 @@ class Aligner:
|
|
31 |
self.weight_func = compute_weights_norm
|
32 |
|
33 |
def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
|
34 |
-
P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
|
35 |
print(log.keys())
|
36 |
if torch.is_tensor(P):
|
37 |
P = P.to('cpu').numpy()
|
38 |
loss = log.get('cost', 'NotImplemented')
|
39 |
|
40 |
-
return P, Cost, loss, similarity_matrix
|
41 |
|
42 |
-
|
43 |
def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
|
44 |
s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
|
45 |
s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
|
46 |
|
47 |
-
C, similarity_matrix = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
|
48 |
s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
|
49 |
|
50 |
if self.ot_type == 'ot':
|
@@ -64,14 +62,8 @@ class Aligner:
|
|
64 |
P = min_max_scaling(P)
|
65 |
|
66 |
elif self.ot_type == 'pot':
|
67 |
-
if self.chimera:
|
68 |
-
m = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
|
69 |
-
m = min(1.0, m.item())
|
70 |
-
else:
|
71 |
-
m = self.tau
|
72 |
-
|
73 |
s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
|
74 |
-
m = np.min((np.sum(s1_weights), np.sum(s2_weights))) *
|
75 |
|
76 |
if self.sinkhorn:
|
77 |
P, log = ot.partial.entropic_partial_wasserstein(
|
@@ -86,10 +78,7 @@ class Aligner:
|
|
86 |
P = min_max_scaling(P)
|
87 |
|
88 |
elif 'uot' in self.ot_type:
|
89 |
-
|
90 |
-
tau = self.tau * self.bertscore_F1(s1_word_embeddigs, s2_word_embeddigs)
|
91 |
-
else:
|
92 |
-
tau = self.tau
|
93 |
|
94 |
if self.ot_type == 'uot':
|
95 |
P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
|
@@ -107,7 +96,7 @@ class Aligner:
|
|
107 |
elif self.ot_type == 'none':
|
108 |
P = 1 - C
|
109 |
|
110 |
-
return P, C, log, similarity_matrix
|
111 |
|
112 |
def convert_to_numpy(self, s1_weights, s2_weights, C):
|
113 |
if torch.is_tensor(s1_weights):
|
@@ -116,4 +105,4 @@ class Aligner:
|
|
116 |
if torch.is_tensor(C):
|
117 |
C = C.to('cpu').numpy()
|
118 |
|
119 |
-
return s1_weights, s2_weights, C
|
|
|
10 |
)
|
11 |
|
12 |
class Aligner:
|
13 |
+
def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
|
14 |
self.ot_type = ot_type
|
15 |
self.sinkhorn = sinkhorn
|
|
|
16 |
self.dist_type = dist_type
|
17 |
self.weight_type = weight_type
|
18 |
self.distotion = distortion
|
|
|
30 |
self.weight_func = compute_weights_norm
|
31 |
|
32 |
def compute_alignment_matrixes(self, s1_word_embeddigs, s2_word_embeddigs):
|
33 |
+
P, Cost, log, similarity_matrix, relative_distance = self.compute_optimal_transport(s1_word_embeddigs, s2_word_embeddigs)
|
34 |
print(log.keys())
|
35 |
if torch.is_tensor(P):
|
36 |
P = P.to('cpu').numpy()
|
37 |
loss = log.get('cost', 'NotImplemented')
|
38 |
|
39 |
+
return P, Cost, loss, similarity_matrix, relative_distance
|
40 |
|
|
|
41 |
def compute_optimal_transport(self, s1_word_embeddigs, s2_word_embeddigs):
|
42 |
s1_word_embeddigs = s1_word_embeddigs.to(torch.float64)
|
43 |
s2_word_embeddigs = s2_word_embeddigs.to(torch.float64)
|
44 |
|
45 |
+
C, similarity_matrix, relative_distance = self.dist_func(s1_word_embeddigs, s2_word_embeddigs, self.distotion)
|
46 |
s1_weights, s2_weights = self.weight_func(s1_word_embeddigs, s2_word_embeddigs)
|
47 |
|
48 |
if self.ot_type == 'ot':
|
|
|
62 |
P = min_max_scaling(P)
|
63 |
|
64 |
elif self.ot_type == 'pot':
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
|
66 |
+
m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau
|
67 |
|
68 |
if self.sinkhorn:
|
69 |
P, log = ot.partial.entropic_partial_wasserstein(
|
|
|
78 |
P = min_max_scaling(P)
|
79 |
|
80 |
elif 'uot' in self.ot_type:
|
81 |
+
tau = self.tau
|
|
|
|
|
|
|
82 |
|
83 |
if self.ot_type == 'uot':
|
84 |
P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
|
|
|
96 |
elif self.ot_type == 'none':
|
97 |
P = 1 - C
|
98 |
|
99 |
+
return P, C, log, similarity_matrix, relative_distance
|
100 |
|
101 |
def convert_to_numpy(self, s1_weights, s2_weights, C):
|
102 |
if torch.is_tensor(s1_weights):
|
|
|
105 |
if torch.is_tensor(C):
|
106 |
C = C.to('cpu').numpy()
|
107 |
|
108 |
+
return s1_weights, s2_weights, C
|
app.py
CHANGED
@@ -1,44 +1,53 @@
|
|
1 |
-
import streamlit as st
|
2 |
import random
|
|
|
3 |
import numpy as np
|
|
|
4 |
import torch
|
|
|
5 |
from nltk.tokenize import word_tokenize
|
6 |
-
from transformers import
|
|
|
7 |
from aligner import Aligner
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
12 |
)
|
13 |
-
from
|
14 |
|
15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
torch.manual_seed(42)
|
17 |
np.random.seed(42)
|
18 |
random.seed(42)
|
19 |
import nltk
|
20 |
-
|
|
|
21 |
|
22 |
|
23 |
@st.cache_resource
|
24 |
def init_model(model: str):
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model)
|
26 |
-
model =
|
|
|
|
|
27 |
return tokenizer, model
|
28 |
|
29 |
|
30 |
@st.cache_resource(max_entries=100)
|
31 |
-
def init_aligner(
|
|
|
|
|
32 |
return Aligner(
|
33 |
ot_type=ot_type,
|
34 |
sinkhorn=sinkhorn,
|
35 |
-
chimera=False,
|
36 |
dist_type="cos",
|
37 |
weight_type="uniform",
|
38 |
distortion=distortion,
|
39 |
-
thresh=threshhold,
|
40 |
-
tau=tau,
|
41 |
-
div_type="--"
|
42 |
)
|
43 |
|
44 |
|
@@ -47,51 +56,70 @@ def main():
|
|
47 |
|
48 |
# Sidebar
|
49 |
st.sidebar.markdown("## Settings & Parameters")
|
50 |
-
model = st.sidebar.selectbox(
|
|
|
|
|
51 |
layer = st.sidebar.slider(
|
52 |
-
|
|
|
|
|
|
|
53 |
)
|
54 |
-
is_centering = st.sidebar.checkbox(
|
55 |
ot_type = st.sidebar.selectbox(
|
56 |
-
|
57 |
-
help="optimal transport algorithm to be used"
|
58 |
)
|
59 |
ot_type = ot_type.lower()
|
60 |
sinkhorn = st.sidebar.checkbox(
|
61 |
-
|
62 |
-
help="use sinkhorn algorithm"
|
63 |
)
|
64 |
distortion = st.sidebar.slider(
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
67 |
)
|
68 |
tau = st.sidebar.slider(
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
72 |
threshhold = st.sidebar.slider(
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# Content
|
78 |
-
st.markdown(
|
|
|
|
|
79 |
|
80 |
col1, col2 = st.columns(2)
|
81 |
|
82 |
with col1:
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
with col2:
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
tokenizer, model = init_model(model)
|
96 |
aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)
|
97 |
|
@@ -115,10 +143,25 @@ def main():
|
|
115 |
st.write(f"**word similarity matrix**")
|
116 |
fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
|
117 |
st.plotly_chart(fig2, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
st.divider()
|
120 |
st.subheader('Refs')
|
121 |
st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
|
122 |
|
123 |
if __name__ == '__main__':
|
124 |
-
main()
|
|
|
|
|
1 |
import random
|
2 |
+
|
3 |
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
import torch
|
6 |
+
import umap
|
7 |
from nltk.tokenize import word_tokenize
|
8 |
+
from transformers import AutoModel, AutoTokenizer
|
9 |
+
|
10 |
from aligner import Aligner
|
11 |
+
|
12 |
+
# from utils import align_matrix_heatmap, plot_align_matrix_heatmap
|
13 |
+
from plotools import (
|
14 |
+
plot_align_matrix_heatmap_plotly,
|
15 |
+
plot_similarity_matrix_heatmap_plotly,
|
16 |
+
show_assignments_plotly,
|
17 |
)
|
18 |
+
from utils import centering, convert_to_word_embeddings, encode_sentence
|
19 |
|
20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
torch.manual_seed(42)
|
22 |
np.random.seed(42)
|
23 |
random.seed(42)
|
24 |
import nltk
|
25 |
+
|
26 |
+
nltk.download("punkt")
|
27 |
|
28 |
|
29 |
@st.cache_resource
|
30 |
def init_model(model: str):
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model)
|
32 |
+
model = (
|
33 |
+
AutoModel.from_pretrained(model, output_hidden_states=True).to(device).eval()
|
34 |
+
)
|
35 |
return tokenizer, model
|
36 |
|
37 |
|
38 |
@st.cache_resource(max_entries=100)
|
39 |
+
def init_aligner(
|
40 |
+
ot_type: str, sinkhorn: bool, distortion: float, threshhold: float, tau: float
|
41 |
+
):
|
42 |
return Aligner(
|
43 |
ot_type=ot_type,
|
44 |
sinkhorn=sinkhorn,
|
|
|
45 |
dist_type="cos",
|
46 |
weight_type="uniform",
|
47 |
distortion=distortion,
|
48 |
+
thresh=threshhold, # 0.25252525252525254
|
49 |
+
tau=tau, # 0.9803921568627451
|
50 |
+
div_type="--",
|
51 |
)
|
52 |
|
53 |
|
|
|
56 |
|
57 |
# Sidebar
|
58 |
st.sidebar.markdown("## Settings & Parameters")
|
59 |
+
model = st.sidebar.selectbox(
|
60 |
+
"model", ["microsoft/deberta-v3-base", "bert-base-uncased"]
|
61 |
+
)
|
62 |
layer = st.sidebar.slider(
|
63 |
+
"layer number for embeddings",
|
64 |
+
0,
|
65 |
+
11,
|
66 |
+
value=9,
|
67 |
)
|
68 |
+
is_centering = st.sidebar.checkbox("centering embeddings", value=True)
|
69 |
ot_type = st.sidebar.selectbox(
|
70 |
+
"ot_type", ["POT", "UOT", "OT"], help="optimal transport algorithm to be used"
|
|
|
71 |
)
|
72 |
ot_type = ot_type.lower()
|
73 |
sinkhorn = st.sidebar.checkbox(
|
74 |
+
"sinkhorn", value=True, help="use sinkhorn algorithm"
|
|
|
75 |
)
|
76 |
distortion = st.sidebar.slider(
|
77 |
+
"distortion: $\kappa$",
|
78 |
+
0.0,
|
79 |
+
1.0,
|
80 |
+
value=0.20,
|
81 |
+
help="suppression of off-diagonal alignments",
|
82 |
)
|
83 |
tau = st.sidebar.slider(
|
84 |
+
"m / $\\tau$",
|
85 |
+
0.0,
|
86 |
+
1.0,
|
87 |
+
value=0.98,
|
88 |
+
help="fraction of fertility to be aligned (fraction of mass to be transported) / penalties",
|
89 |
+
) # with 0.02 interva
|
90 |
threshhold = st.sidebar.slider(
|
91 |
+
"threshhold: $\lambda$",
|
92 |
+
0.0,
|
93 |
+
1.0,
|
94 |
+
value=0.22,
|
95 |
+
help="sparsity of alignment matrix",
|
96 |
+
) # with 0.01 interval
|
97 |
+
show_assignments = st.sidebar.checkbox("show assignments", value=True)
|
98 |
+
if show_assignments:
|
99 |
+
n_neighbors = st.sidebar.slider(
|
100 |
+
"n_neighbors", 2, 10, value=8, help="number of neighbors for umap"
|
101 |
+
)
|
102 |
|
103 |
# Content
|
104 |
+
st.markdown(
|
105 |
+
"## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment"
|
106 |
+
)
|
107 |
|
108 |
col1, col2 = st.columns(2)
|
109 |
|
110 |
with col1:
|
111 |
+
sent1 = st.text_area(
|
112 |
+
"sentence 1",
|
113 |
+
"By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.",
|
114 |
+
help="Initial text",
|
115 |
+
)
|
116 |
with col2:
|
117 |
+
sent2 = st.text_area(
|
118 |
+
"sentence 2",
|
119 |
+
"Today there are only around 20,000 wild lions left in the world.",
|
120 |
+
help="Text to compare",
|
121 |
+
)
|
122 |
+
|
123 |
tokenizer, model = init_model(model)
|
124 |
aligner = init_aligner(ot_type, sinkhorn, distortion, threshhold, tau)
|
125 |
|
|
|
143 |
st.write(f"**word similarity matrix**")
|
144 |
fig2 = plot_similarity_matrix_heatmap_plotly(similarity_matrix.T, sent1, sent2, cost_matrix.T)
|
145 |
st.plotly_chart(fig2, use_container_width=True)
|
146 |
+
|
147 |
+
if show_assignments:
|
148 |
+
st.write(f"**Alignments after UMAP**")
|
149 |
+
word_embeddings = torch.vstack([s1_vec, s2_vec])
|
150 |
+
umap_embeddings = umap.UMAP(
|
151 |
+
n_neighbors=n_neighbors,
|
152 |
+
n_components=2,
|
153 |
+
random_state=42,
|
154 |
+
metric="cosine",
|
155 |
+
).fit_transform(word_embeddings.detach().numpy())
|
156 |
+
print(umap_embeddings.shape)
|
157 |
+
fig3 = show_assignments_plotly(
|
158 |
+
align_matrix, umap_embeddings, sent1, sent2, thr=threshhold
|
159 |
+
)
|
160 |
+
st.plotly_chart(fig3, use_container_width=True)
|
161 |
|
162 |
st.divider()
|
163 |
st.subheader('Refs')
|
164 |
st.write("Yuki Arase, Han Bao, Sho Yokoi, [Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), ACL2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]")
|
165 |
|
166 |
if __name__ == '__main__':
|
167 |
+
main()
|
otfuncs.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1 |
-
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
from ot.backend import get_backend
|
5 |
|
6 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
C = min_max_scaling(C) # Range 0-1
|
12 |
C = 1.0 - C # Convert to distance
|
13 |
|
14 |
-
return C, sim_matrix
|
15 |
|
16 |
|
17 |
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
|
@@ -30,15 +35,20 @@ def apply_distortion(sim_matrix, ratio):
|
|
30 |
if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
|
31 |
return sim_matrix
|
32 |
|
33 |
-
pos_x = torch.tensor(
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
sim_matrix = torch.mul(sim_matrix, distortion_mask)
|
40 |
|
41 |
-
return sim_matrix
|
42 |
|
43 |
|
44 |
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
|
@@ -48,8 +58,12 @@ def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
|
|
48 |
|
49 |
|
50 |
def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
|
51 |
-
s1_weights = torch.ones(
|
52 |
-
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# # Uniform weights to make L2 norm=1
|
55 |
# s1_weights /= torch.linalg.norm(s1_weights)
|
@@ -65,4 +79,4 @@ def min_max_scaling(C):
|
|
65 |
C_min = nx.min(C)
|
66 |
C_max = nx.max(C)
|
67 |
C = (C - C_min + eps) / (C_max - C_min + eps)
|
68 |
-
return C
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from ot.backend import get_backend
|
4 |
|
5 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
6 |
|
7 |
+
|
8 |
+
def compute_distance_matrix_cosine(
|
9 |
+
s1_word_embeddigs, s2_word_embeddigs, distortion_ratio
|
10 |
+
):
|
11 |
+
sim_matrix = (
|
12 |
+
torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
|
13 |
+
+ 1.0
|
14 |
+
) / 2 # Range 0-1
|
15 |
+
C, relative_distance = apply_distortion(sim_matrix, distortion_ratio)
|
16 |
C = min_max_scaling(C) # Range 0-1
|
17 |
C = 1.0 - C # Convert to distance
|
18 |
|
19 |
+
return C, sim_matrix, relative_distance
|
20 |
|
21 |
|
22 |
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
|
|
|
35 |
if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
|
36 |
return sim_matrix
|
37 |
|
38 |
+
pos_x = torch.tensor(
|
39 |
+
[[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
|
40 |
+
device=device,
|
41 |
+
)
|
42 |
+
pos_y = torch.tensor(
|
43 |
+
[[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
|
44 |
+
device=device,
|
45 |
+
)
|
46 |
+
relative_distance = (pos_x - pos_y.T) ** 2
|
47 |
+
distortion_mask = 1.0 - relative_distance * ratio
|
48 |
|
49 |
sim_matrix = torch.mul(sim_matrix, distortion_mask)
|
50 |
|
51 |
+
return sim_matrix, relative_distance
|
52 |
|
53 |
|
54 |
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
|
|
|
58 |
|
59 |
|
60 |
def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
|
61 |
+
s1_weights = torch.ones(
|
62 |
+
s1_word_embeddigs.shape[0], dtype=torch.float64, device=device
|
63 |
+
)
|
64 |
+
s2_weights = torch.ones(
|
65 |
+
s2_word_embeddigs.shape[0], dtype=torch.float64, device=device
|
66 |
+
)
|
67 |
|
68 |
# # Uniform weights to make L2 norm=1
|
69 |
# s1_weights /= torch.linalg.norm(s1_weights)
|
|
|
79 |
C_min = nx.min(C)
|
80 |
C_max = nx.max(C)
|
81 |
C = (C - C_min + eps) / (C_max - C_min + eps)
|
82 |
+
return C
|
plotools.py
CHANGED
@@ -8,74 +8,79 @@ def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]):
|
|
8 |
using zero-width-space
|
9 |
cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
|
10 |
"""
|
11 |
-
sent1 = [word + i*
|
12 |
-
sent2 = [word + i*
|
13 |
-
|
14 |
return sent1, sent2
|
15 |
|
16 |
|
17 |
def discrete_colorscale(bvals, colors):
|
18 |
"""
|
19 |
bvals - list of values bounding intervals/ranges of interest
|
20 |
-
colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0
|
21 |
returns the plotly discrete colorscale
|
22 |
ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
|
23 |
"""
|
24 |
-
if len(bvals) != len(colors)+1:
|
25 |
-
raise ValueError(
|
26 |
-
bvals = sorted(bvals)
|
27 |
-
nvals = [
|
28 |
-
|
29 |
-
|
|
|
|
|
30 |
for k in range(len(colors)):
|
31 |
-
dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
|
32 |
-
return dcolorscale
|
33 |
|
34 |
|
35 |
def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
|
36 |
align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
|
37 |
sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
|
38 |
-
_colors = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
_ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
|
40 |
|
41 |
colorscale = discrete_colorscale(_ticks, _colors)
|
42 |
|
43 |
fig = go.Figure()
|
44 |
-
|
45 |
-
fig.add_trace(
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
tick0=0,
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
'cost: %{customdata:.3f} ',
|
63 |
-
name=''
|
64 |
-
))
|
65 |
fig.update_layout(
|
66 |
-
#xaxis=dict(scaleanchor='y'),
|
67 |
-
yaxis=dict(autorange=
|
68 |
-
margin={
|
69 |
-
plot_bgcolor=
|
70 |
font=dict(
|
71 |
size=16,
|
72 |
),
|
73 |
hoverlabel=dict(
|
74 |
-
bgcolor="#555",
|
75 |
-
|
76 |
-
font_size=14,
|
77 |
-
font_family="Open Sans"
|
78 |
-
)
|
79 |
)
|
80 |
fig.update_xaxes(
|
81 |
tickangle=-45,
|
@@ -83,47 +88,104 @@ def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cos
|
|
83 |
return fig
|
84 |
|
85 |
|
86 |
-
def plot_similarity_matrix_heatmap_plotly(
|
|
|
|
|
87 |
sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
|
88 |
|
89 |
fig = go.Figure()
|
90 |
-
|
91 |
-
fig.add_trace(
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
tick0=0,
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
name=''
|
110 |
-
))
|
111 |
fig.update_layout(
|
112 |
-
#xaxis=dict(scaleanchor='y'),
|
113 |
-
yaxis=dict(autorange=
|
114 |
-
margin={
|
115 |
-
plot_bgcolor=
|
116 |
font=dict(
|
117 |
size=16,
|
118 |
),
|
119 |
hoverlabel=dict(
|
120 |
-
bgcolor="#555",
|
121 |
-
|
122 |
-
font_size=14,
|
123 |
-
font_family="Open Sans"
|
124 |
-
)
|
125 |
)
|
126 |
fig.update_xaxes(
|
127 |
tickangle=-45,
|
128 |
)
|
129 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
using zero-width-space
|
9 |
cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013
|
10 |
"""
|
11 |
+
sent1 = [word + i * "\u200b" for i, word in enumerate(sent1)]
|
12 |
+
sent2 = [word + i * "\u200b" for i, word in enumerate(sent2)]
|
13 |
+
|
14 |
return sent1, sent2
|
15 |
|
16 |
|
17 |
def discrete_colorscale(bvals, colors):
|
18 |
"""
|
19 |
bvals - list of values bounding intervals/ranges of interest
|
20 |
+
colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0<=k < len(bvals)-1
|
21 |
returns the plotly discrete colorscale
|
22 |
ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780
|
23 |
"""
|
24 |
+
if len(bvals) != len(colors) + 1:
|
25 |
+
raise ValueError("len(boundary values) should be equal to len(colors)+1")
|
26 |
+
bvals = sorted(bvals)
|
27 |
+
nvals = [
|
28 |
+
(v - bvals[0]) / (bvals[-1] - bvals[0]) for v in bvals
|
29 |
+
] # normalized values
|
30 |
+
|
31 |
+
dcolorscale = [] # discrete colorscale
|
32 |
for k in range(len(colors)):
|
33 |
+
dcolorscale.extend([[nvals[k], colors[k]], [nvals[k + 1], colors[k]]])
|
34 |
+
return dcolorscale
|
35 |
|
36 |
|
37 |
def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
|
38 |
align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)
|
39 |
sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
|
40 |
+
_colors = [
|
41 |
+
"#F2F2F2",
|
42 |
+
"#E0F4FA",
|
43 |
+
"#BEE4F0",
|
44 |
+
"#88CCE5",
|
45 |
+
"#33b7df",
|
46 |
+
"#1B88A6",
|
47 |
+
"#105264",
|
48 |
+
"#092E39",
|
49 |
+
]
|
50 |
_ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
|
51 |
|
52 |
colorscale = discrete_colorscale(_ticks, _colors)
|
53 |
|
54 |
fig = go.Figure()
|
55 |
+
|
56 |
+
fig.add_trace(
|
57 |
+
go.Heatmap(
|
58 |
+
z=align_matrix,
|
59 |
+
customdata=Cost,
|
60 |
+
x=sent1,
|
61 |
+
y=sent2,
|
62 |
+
xgap=2,
|
63 |
+
ygap=2,
|
64 |
+
colorscale=colorscale,
|
65 |
+
colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
|
66 |
+
hovertemplate="x: %{x}<br>"
|
67 |
+
+ "y: %{y}<br>"
|
68 |
+
+ "P: %{z:.3f}<br>"
|
69 |
+
+ "cost: %{customdata:.3f} ",
|
70 |
+
name="",
|
71 |
+
)
|
72 |
+
)
|
|
|
|
|
|
|
73 |
fig.update_layout(
|
74 |
+
# xaxis=dict(scaleanchor='y'),
|
75 |
+
yaxis=dict(autorange="reversed"),
|
76 |
+
margin={"l": 0, "r": 0, "t": 0, "b": 0},
|
77 |
+
plot_bgcolor="rgba(0,0,0,0)",
|
78 |
font=dict(
|
79 |
size=16,
|
80 |
),
|
81 |
hoverlabel=dict(
|
82 |
+
bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
|
83 |
+
),
|
|
|
|
|
|
|
84 |
)
|
85 |
fig.update_xaxes(
|
86 |
tickangle=-45,
|
|
|
88 |
return fig
|
89 |
|
90 |
|
91 |
+
def plot_similarity_matrix_heatmap_plotly(
|
92 |
+
similarity_matrix, sent1, sent2, Cost, colorscale="Reds", hover_z="cosine"
|
93 |
+
):
|
94 |
sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)
|
95 |
|
96 |
fig = go.Figure()
|
97 |
+
|
98 |
+
fig.add_trace(
|
99 |
+
go.Heatmap(
|
100 |
+
z=similarity_matrix,
|
101 |
+
customdata=Cost,
|
102 |
+
x=sent1,
|
103 |
+
y=sent2,
|
104 |
+
xgap=2,
|
105 |
+
ygap=2,
|
106 |
+
colorscale=colorscale,
|
107 |
+
colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
|
108 |
+
hovertemplate="x: %{x}<br>"
|
109 |
+
+ "y: %{y}<br>"
|
110 |
+
+ f"{hover_z}: "
|
111 |
+
+ "%{z:.3f}<br>"
|
112 |
+
+ "cost: %{customdata:.3f} ",
|
113 |
+
name="",
|
114 |
+
)
|
115 |
+
)
|
|
|
|
|
116 |
fig.update_layout(
|
117 |
+
# xaxis=dict(scaleanchor='y'),
|
118 |
+
yaxis=dict(autorange="reversed"),
|
119 |
+
margin={"l": 0, "r": 0, "t": 0, "b": 0},
|
120 |
+
plot_bgcolor="rgba(0,0,0,0)",
|
121 |
font=dict(
|
122 |
size=16,
|
123 |
),
|
124 |
hoverlabel=dict(
|
125 |
+
bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
|
126 |
+
),
|
|
|
|
|
|
|
127 |
)
|
128 |
fig.update_xaxes(
|
129 |
tickangle=-45,
|
130 |
)
|
131 |
+
return fig
|
132 |
+
|
133 |
+
|
134 |
+
def show_assignments_plotly(P, word_embeddings, sents1, sents2, thr=0):
|
135 |
+
P = np.where(P <= thr, 0, P)
|
136 |
+
|
137 |
+
s1_end = len(sents1)
|
138 |
+
a = word_embeddings[:s1_end]
|
139 |
+
b = word_embeddings[s1_end:]
|
140 |
+
|
141 |
+
traces = []
|
142 |
+
sample = 0
|
143 |
+
|
144 |
+
for i in range(a.shape[0]):
|
145 |
+
for j in range(b.shape[0]):
|
146 |
+
if P[i, j] > 0:
|
147 |
+
sample += 1
|
148 |
+
traces.append(
|
149 |
+
go.Scatter(
|
150 |
+
x=[a[i, 0], b[j, 0]],
|
151 |
+
y=[a[i, 1], b[j, 1]],
|
152 |
+
mode="lines",
|
153 |
+
line=dict(color="black", width=P[i, j] * 2),
|
154 |
+
opacity=P[i, j],
|
155 |
+
name=f"{sample}",
|
156 |
+
)
|
157 |
+
)
|
158 |
+
|
159 |
+
# ć½ć¼ć¹ćµć³ćć«ć®ęē»
|
160 |
+
traces.append(
|
161 |
+
go.Scatter(
|
162 |
+
x=a[:, 0],
|
163 |
+
y=a[:, 1],
|
164 |
+
mode="markers+text",
|
165 |
+
marker=dict(color="blue", size=8, symbol="cross"),
|
166 |
+
text=sents1,
|
167 |
+
textposition="top center",
|
168 |
+
name="Source samples",
|
169 |
+
)
|
170 |
+
)
|
171 |
+
|
172 |
+
# ćæć¼ć²ćććµć³ćć«ć®ęē»
|
173 |
+
traces.append(
|
174 |
+
go.Scatter(
|
175 |
+
x=b[:, 0],
|
176 |
+
y=b[:, 1],
|
177 |
+
mode="markers+text",
|
178 |
+
marker=dict(color="red", size=8, symbol="x"),
|
179 |
+
text=sents2,
|
180 |
+
textposition="bottom center",
|
181 |
+
name="Target samples",
|
182 |
+
)
|
183 |
+
)
|
184 |
+
|
185 |
+
layout = go.Layout(
|
186 |
+
showlegend=True,
|
187 |
+
margin=dict(l=0, r=0, t=10, b=0),
|
188 |
+
)
|
189 |
+
|
190 |
+
fig = go.Figure(data=traces, layout=layout)
|
191 |
+
return fig
|
requirements.txt
CHANGED
@@ -6,4 +6,5 @@ transformers==4.30.2
|
|
6 |
matplotlib==3.7.1
|
7 |
plotly==5.15.0
|
8 |
torch==2.0.1
|
9 |
-
nltk==3.8.1
|
|
|
|
6 |
matplotlib==3.7.1
|
7 |
plotly==5.15.0
|
8 |
torch==2.0.1
|
9 |
+
nltk==3.8.1
|
10 |
+
umap-learn==0.5.5
|