|
import numpy as np |
|
import plotly.graph_objects as go |
|
|
|
|
|
def _debug_non_unique_axis_values(sent1: list[str], sent2: list[str]): |
|
""" |
|
solution: |
|
using zero-width-space |
|
cf. https://github.com/plotly/plotly.js/issues/1516#issuecomment-983090013 |
|
""" |
|
sent1 = [word + i * "\u200b" for i, word in enumerate(sent1)] |
|
sent2 = [word + i * "\u200b" for i, word in enumerate(sent2)] |
|
|
|
return sent1, sent2 |
|
|
|
|
|
def discrete_colorscale(bvals, colors):
    """Build a stepwise (discrete) Plotly colorscale.

    Each color is held constant across one interval between consecutive
    boundary values, producing hard color bands instead of a gradient.

    ref. https://community.plotly.com/t/colors-for-discrete-ranges-in-heatmaps/7780

    Args:
        bvals: Values bounding the intervals of interest. They are sorted and
            then linearly normalized to [0, 1].
        colors: rgb or hex color codes; ``colors[k]`` covers the interval
            ``[bvals[k], bvals[k + 1]]`` for ``0 <= k < len(bvals) - 1``.

    Returns:
        A list of ``[position, color]`` pairs (two per color) usable as a
        trace's ``colorscale``.

    Raises:
        ValueError: if ``len(bvals) != len(colors) + 1``, or if all boundary
            values are equal (there is no range to normalize over).
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError("len(boundary values) should be equal to len(colors)+1")

    bvals = sorted(bvals)
    lo, hi = bvals[0], bvals[-1]
    span = hi - lo
    if span == 0:
        # A zero-width range would otherwise divide by zero below.
        raise ValueError("boundary values must span a non-zero range")

    nvals = [(v - lo) / span for v in bvals]

    dcolorscale = []
    # Repeating each color at both ends of its interval makes the scale jump
    # abruptly at every boundary instead of interpolating between colors.
    for (start, end), color in zip(zip(nvals, nvals[1:]), colors):
        dcolorscale.extend([[start, color], [end, color]])
    return dcolorscale
|
|
|
|
|
def plot_align_matrix_heatmap_plotly(align_matrix, sent1, sent2, threshhold, Cost):
    """Plot an alignment matrix as a heatmap with an 8-band discrete colorscale.

    Args:
        align_matrix: 2-D array-like of alignment weights. Plotly draws
            ``z[row][col]`` at ``(x[col], y[row])``, so rows follow ``sent2``
            and columns follow ``sent1``.
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis (drawn top-to-bottom because the
            y axis is reversed).
        threshhold: Cells whose value is ``<= threshhold`` are set to 0.
        Cost: Per-cell values shown as ``cost`` in the hover label
            (presumably the same shape as ``align_matrix`` — confirm).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # np.asarray lets nested Python lists be passed; ndarrays pass through as-is.
    align_matrix = np.asarray(align_matrix)
    align_matrix = np.where(align_matrix <= threshhold, 0, align_matrix)

    # Pad labels so repeated words are not merged into one axis category.
    sent1, sent2 = _debug_non_unique_axis_values(sent1, sent2)

    # Eight colors from light to dark, one per 1/8-wide band of [0, 1].
    _colors = [
        "#F2F2F2",
        "#E0F4FA",
        "#BEE4F0",
        "#88CCE5",
        "#33b7df",
        "#1B88A6",
        "#105264",
        "#092E39",
    ]
    _ticks = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    colorscale = discrete_colorscale(_ticks, _colors)

    # NOTE(review): zmin/zmax are not set, so Plotly stretches the colorscale
    # over the data's own min..max instead of a fixed [0, 1]. The 0.125
    # colorbar spacing suggests [0, 1] was intended — confirm before pinning
    # zmin=0, zmax=1.
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            z=align_matrix,
            customdata=Cost,
            x=sent1,
            y=sent2,
            xgap=2,
            ygap=2,
            colorscale=colorscale,
            colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
            hovertemplate="x: %{x}<br>"
            + "y: %{y}<br>"
            + "P: %{z:.3f}<br>"
            + "cost: %{customdata:.3f} ",
            # Empty trace name keeps the trace label out of the hover box.
            name="",
        )
    )
    fig.update_layout(
        yaxis=dict(autorange="reversed"),
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        plot_bgcolor="rgba(0,0,0,0)",
        font=dict(
            size=16,
        ),
        hoverlabel=dict(
            bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
        ),
    )
    fig.update_xaxes(
        tickangle=-45,
    )
    return fig
|
|
|
|
|
def plot_similarity_matrix_heatmap_plotly(
    similarity_matrix, sent1, sent2, Cost, colorscale="Reds", hover_z="cosine"
):
    """Plot a token-by-token similarity matrix as a continuous heatmap.

    Args:
        similarity_matrix: Values passed as the heatmap ``z``; Plotly places
            ``z[row][col]`` at ``(x[col], y[row])``.
        sent1: Tokens labelling the x axis.
        sent2: Tokens labelling the y axis (drawn top-to-bottom because the
            y axis is reversed).
        Cost: Per-cell values shown as ``cost`` in the hover label.
        colorscale: Plotly colorscale name or definition (default ``"Reds"``).
        hover_z: Label shown in front of the ``z`` value in the hover text.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # Pad labels so repeated words are not merged into one axis category.
    x_labels, y_labels = _debug_non_unique_axis_values(sent1, sent2)

    hover_parts = (
        "x: %{x}",
        "y: %{y}",
        f"{hover_z}: %{{z:.3f}}",
        "cost: %{customdata:.3f} ",
    )
    hovertemplate = "<br>".join(hover_parts)

    heatmap = go.Heatmap(
        z=similarity_matrix,
        customdata=Cost,
        x=x_labels,
        y=y_labels,
        xgap=2,
        ygap=2,
        colorscale=colorscale,
        colorbar=dict(tick0=0, dtick=0.125, outlinewidth=0),
        hovertemplate=hovertemplate,
        # Empty trace name keeps the trace label out of the hover box.
        name="",
    )

    fig = go.Figure()
    fig.add_trace(heatmap)
    fig.update_layout(
        yaxis_autorange="reversed",
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor="rgba(0,0,0,0)",
        font_size=16,
        hoverlabel=dict(
            bgcolor="#555", font_color="white", font_size=14, font_family="Open Sans"
        ),
    )
    fig.update_xaxes(tickangle=-45)
    return fig
|
|
|
|
|
def show_assignments_plotly(P, word_embeddings, sents1, sents2, thr=0):
    """Draw weighted links between source and target points in 2-D.

    The first ``len(sents1)`` rows of ``word_embeddings`` are the source
    points and the remaining rows are the target points. Only columns 0 and 1
    are used as (x, y) coordinates. Each positive entry ``P[i, j]`` (after
    thresholding) becomes a black line from source ``i`` to target ``j``.
    Its width is ``2 * P[i, j]`` and its opacity is ``P[i, j]``.

    Args:
        P: Weight matrix indexed ``[source, target]``; entries ``<= thr`` are
            dropped. Assumed to be a numpy array with values in [0, 1], since
            they are used as opacity — confirm with callers.
        word_embeddings: Array-like supporting 2-D slicing (``[:, 0]``),
            holding the source rows followed by the target rows.
        sents1: Labels for the source points (its length sets the split).
        sents2: Labels for the target points.
        thr: Weights at or below this value are not drawn.

    Returns:
        A ``plotly.graph_objects.Figure`` with one line trace per link, then a
        "Source samples" and a "Target samples" marker trace.
    """
    P = np.where(P <= thr, 0, P)

    split = len(sents1)
    src_pts = word_embeddings[:split]
    tgt_pts = word_embeddings[split:]

    # Positive-weight (source, target) pairs in row-major order, which fixes
    # the legend numbering of the link traces.
    links = [
        (r, c)
        for r in range(src_pts.shape[0])
        for c in range(tgt_pts.shape[0])
        if P[r, c] > 0
    ]

    link_traces = [
        go.Scatter(
            x=[src_pts[r, 0], tgt_pts[c, 0]],
            y=[src_pts[r, 1], tgt_pts[c, 1]],
            mode="lines",
            line=dict(color="black", width=P[r, c] * 2),
            opacity=P[r, c],
            name=f"{rank}",
        )
        for rank, (r, c) in enumerate(links, start=1)
    ]

    source_markers = go.Scatter(
        x=src_pts[:, 0],
        y=src_pts[:, 1],
        mode="markers+text",
        marker=dict(color="blue", size=8, symbol="cross"),
        text=sents1,
        textposition="top center",
        name="Source samples",
    )
    target_markers = go.Scatter(
        x=tgt_pts[:, 0],
        y=tgt_pts[:, 1],
        mode="markers+text",
        marker=dict(color="red", size=8, symbol="x"),
        text=sents2,
        textposition="bottom center",
        name="Target samples",
    )

    return go.Figure(
        data=link_traces + [source_markers, target_markers],
        layout=go.Layout(showlegend=True, margin=dict(l=0, r=0, t=10, b=0)),
    )
|
|