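"""Streamlit playground for unbalanced optimal transport (OT) word alignment.

An interactive demo of "Unbalanced Optimal Transport for Unbalanced Word
Alignment" (Arase, Bao, Yokoi; ACL 2023), linked in the app footer. Assuming
this file is saved as app.py, launch it with:

    streamlit run app.py

Besides the local modules (aligner, plotools, utils), this depends on
streamlit, torch, transformers, nltk, and umap-learn.
"""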
import random

import nltk
import numpy as np
import streamlit as st
import torch
import umap
from nltk.tokenize import word_tokenize
from transformers import AutoModel, AutoTokenizer

from aligner import Aligner
from plotools import (
    plot_align_matrix_heatmap_plotly,
    plot_similarity_matrix_heatmap_plotly,
    show_assignments_plotly,
)
from utils import centering, convert_to_word_embeddings, encode_sentence

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

nltk.download("punkt")
@st.cache_resource
def init_model(model_name: str):
    """Load the tokenizer and encoder once; hidden states are kept so a layer can be selected."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = (
        AutoModel.from_pretrained(model_name, output_hidden_states=True)
        .to(device)
        .eval()
    )
    return tokenizer, model

@st.cache_resource(max_entries=100)
def init_aligner(
    ot_type: str, sinkhorn: bool, distortion: float, threshold: float, tau: float
):
    return Aligner(
        ot_type=ot_type,
        sinkhorn=sinkhorn,
        dist_type="cos",
        weight_type="uniform",
        distortion=distortion,
        thresh=threshold,
        tau=tau,
        div_type="--",
    )
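# Note on ot_type (an interpretation based on the ACL 2023 paper linked in the
# app footer; the Aligner API itself lives in aligner.py):
#   "ot"  - balanced OT: every word's mass must be transported.
#   "uot" - unbalanced OT: the marginal constraints are relaxed, with tau
#           acting as the penalty strength.
#   "pot" - partial OT: only a fraction m of the total mass (the same slider
#           value) is aligned.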


def main():
    st.set_page_config(layout="wide")

    st.sidebar.markdown("## Settings & Parameters")
    model_name = st.sidebar.selectbox(
        "model", ["microsoft/deberta-v3-base", "bert-base-uncased"]
    )
    layer = st.sidebar.slider(
        "layer number for embeddings",
        0,
        11,
        value=9,
    )
    is_centering = st.sidebar.checkbox("centering embeddings", value=True)
    ot_type = st.sidebar.selectbox(
        "ot_type", ["POT", "UOT", "OT"], help="optimal transport algorithm to be used"
    ).lower()
    sinkhorn = st.sidebar.checkbox(
        "sinkhorn", value=True, help="use the Sinkhorn algorithm"
    )
    distortion = st.sidebar.slider(
        r"distortion: $\kappa$",
        0.0,
        1.0,
        value=0.20,
        help="suppression of off-diagonal alignments",
    )
    tau = st.sidebar.slider(
        r"m / $\tau$",
        0.0,
        1.0,
        value=0.98,
        help="fraction of mass to be transported (m, for POT) / marginal penalty (tau, for UOT)",
    )
    threshold = st.sidebar.slider(
        r"threshold: $\lambda$",
        0.0,
        1.0,
        value=0.22,
        help="sparsity of the alignment matrix",
    )
    show_assignments = st.sidebar.checkbox("show assignments", value=True)
    if show_assignments:
        n_neighbors = st.sidebar.slider(
            "n_neighbors", 2, 10, value=8, help="number of neighbors for UMAP"
        )

    st.markdown(
        "## Playground: Unbalanced Optimal Transport for Unbalanced Word Alignment"
    )

    col1, col2 = st.columns(2)

    with col1:
        sent1 = st.text_area(
            "sentence 1",
            "By one estimate, fewer than 20,000 lions exist in the wild, a drop of about 40 percent in the past two decades.",
            help="Initial text",
        )
    with col2:
        sent2 = st.text_area(
            "sentence 2",
            "Today there are only around 20,000 wild lions left in the world.",
            help="Text to compare",
        )

    tokenizer, model = init_model(model_name)
    aligner = init_aligner(ot_type, sinkhorn, distortion, threshold, tau)

    with st.container():
        if sent1 != "" and sent2 != "":
            # Align at the word level: tokenize, encode, then pool subword
            # hidden states back into one vector per word.
            sent1 = word_tokenize(sent1.lower())
            sent2 = word_tokenize(sent2.lower())
            hidden_output, input_id, offset_map = encode_sentence(
                sent1, sent2, tokenizer, model, layer=layer
            )
            if is_centering:
                hidden_output = centering(hidden_output)
            s1_vec, s2_vec = convert_to_word_embeddings(
                offset_map, input_id, hidden_output, tokenizer, pair=True
            )
            align_matrix, cost_matrix, loss, similarity_matrix = (
                aligner.compute_alignment_matrixes(s1_vec, s2_vec)
            )

            st.write(f"**word alignment matrix** (loss: :blue[{loss}])")
            fig = plot_align_matrix_heatmap_plotly(
                align_matrix.T, sent1, sent2, threshold, cost_matrix.T
            )
            st.plotly_chart(fig, use_container_width=True)

            st.write("**word similarity matrix**")
            fig2 = plot_similarity_matrix_heatmap_plotly(
                similarity_matrix.T, sent1, sent2, cost_matrix.T
            )
            st.plotly_chart(fig2, use_container_width=True)

            if show_assignments:
                st.write("**Alignments after UMAP**")
                # Project the word vectors of both sentences to 2D so the
                # alignments can be drawn as edges between points.
                word_embeddings = torch.vstack([s1_vec, s2_vec])
                umap_embeddings = umap.UMAP(
                    n_neighbors=n_neighbors,
                    n_components=2,
                    random_state=42,
                    metric="cosine",
                ).fit_transform(word_embeddings.detach().cpu().numpy())
                fig3 = show_assignments_plotly(
                    align_matrix, umap_embeddings, sent1, sent2, thr=threshold
                )
                st.plotly_chart(fig3, use_container_width=True)

    st.divider()
    st.subheader("Refs")
    st.write(
        "Yuki Arase, Han Bao, Sho Yokoi, "
        "[Unbalanced Optimal Transport for Unbalanced Word Alignment](https://arxiv.org/abs/2306.04116), "
        "ACL 2023 [[github](https://github.com/yukiar/OTAlign/tree/main)]"
    )


if __name__ == "__main__":
    main()