|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
def generate_diverging_colors(num_colors, palette='Set3'):
    """Sample ``num_colors`` colours from a matplotlib colormap as hex strings.

    Despite the function name, nothing here enforces a diverging palette: the
    colours are whatever the colormap called ``palette`` yields when it is
    resampled to ``num_colors`` entries.

    Args:
        num_colors: Number of colours to produce. Zero or a negative value
            yields an empty list.
        palette: Name of a colormap registered with matplotlib.

    Returns:
        A list of ``num_colors`` six-digit lowercase hex strings in
        ``rrggbb`` form, without a leading ``#``.
    """
    if num_colors <= 0:
        # Nothing to sample; avoid building a zero-entry colormap.
        return []

    # pyplot.get_cmap is used instead of matplotlib.cm.get_cmap because the
    # latter was deprecated in matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(palette, num_colors)

    # Calling the colormap with integer indices returns one RGBA row per index.
    colors_rgba = cmap(np.arange(num_colors))

    def _to_hex(rgba):
        # Channels are floats in [0, 1]; scale to 0-255 (truncating) and pack
        # them into a single 24-bit value. The alpha channel is ignored.
        red, green, blue = (int(channel * 255) for channel in rgba[:3])
        return format((red << 16) | (green << 8) | blue, '06x')

    return [_to_hex(rgba) for rgba in colors_rgba]
|
|
|
|
|
def _render_colored_tokens(tokenizer, input_ids, prefix, colordict, skip_positions=()):
    """Render a token-id sequence as one HTML string of colour-coded spans."""
    spans = []
    for pos, token_id in enumerate(input_ids):
        if pos in skip_positions:
            continue
        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        color = colordict.get(f"{prefix}_{pos}")
        if color is not None:
            css_color = f"#{color}"
        else:
            # Tokens without a group use the page's body-text colour. A CSS
            # custom property is only valid as a value when read via var().
            css_color = "var(--color-text-body)"
        # NOTE(review): token text is interpolated unescaped, so tokens that
        # contain angle brackets (e.g. "</s>") are parsed by a browser as
        # markup. Escaping would change what is displayed, so it is left as-is.
        spans.append(f"<span style='color: {css_color}'>{token}</span>")
    # "▁" is the word-start marker in the token text; show it as a space.
    html = "".join(spans).replace("▁", " ")
    return f"<span style='font-size: 25px'>{html}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids,
                threshold=0.4, skip_first_src=True, skip_second_src=False,
                layer=2, head=6):
    """Colour-code encoder and decoder tokens by cross-attention alignment.

    One cross-attention map, ``outputs.cross_attentions[layer][0][head]``, is
    iterated row by row (row ``i`` is paired with ``decoder_input_ids[0][i]``).
    For each row, the indices whose weight exceeds ``threshold`` are taken as
    the encoder positions aligned with that decoder position. Consecutive
    decoder tokens are merged into one group when they share an aligned
    encoder position with the previous group or when the token does not start
    with the word-start marker "▁". Every group gets a colour, and both
    sequences are rendered as HTML with matching colours.

    Args:
        outputs: Object exposing ``cross_attentions``; each row of the selected
            map must support ``(row > threshold).nonzero().squeeze(-1).tolist()``
            (presumably a torch tensor -- confirm against the caller).
        tokenizer: Object exposing ``convert_ids_to_tokens(list_of_ids)``.
        encoder_input_ids: Batched encoder token ids; only element 0 is used.
        decoder_input_ids: Batched decoder token ids; only element 0 is used.
        threshold: Attention weight above which a position counts as aligned.
        skip_first_src: Omit the encoder token at position 0 from the HTML.
        skip_second_src: Omit the encoder token at position 1 from the HTML.
        layer: Index into ``outputs.cross_attentions``.
        head: Index of the attention head within the selected layer.

    Returns:
        A ``(srchtml, tgthtml)`` tuple: HTML for the encoder tokens and HTML
        for the decoder tokens, each wrapped in a 25px font-size ``<span>``.
    """
    attention_rows = outputs.cross_attentions[layer][0][head]

    # One entry per decoder position: [[decoder_pos], [aligned encoder positions]].
    alignment = [
        [[dec_pos], (row > threshold).nonzero().squeeze(-1).tolist()]
        for dec_pos, row in enumerate(attention_rows)
    ]

    # Merge consecutive decoder positions that belong to the same group.
    special_tokens = ("</s>", "<pad>", "<unk>", "<s>")
    merged = []
    for dec_positions, enc_positions in alignment:
        token = tokenizer.convert_ids_to_tokens(
            [decoder_input_ids[0][dec_positions[0]]]
        )[0]
        if token in special_tokens:
            continue

        if merged:
            previous_enc = merged[-1][1]
            shares_encoder_pos = any(pos in previous_enc for pos in enc_positions)
            # A token without the leading "▁" continues the previous word.
            # startswith also copes with an empty token string.
            continues_word = not token.startswith("▁")
            if shares_encoder_pos or continues_word:
                merged[-1][0] += dec_positions
                merged[-1][1] += enc_positions
                continue

        merged.append([list(dec_positions), list(enc_positions)])

    # Give every group a colour id, reusing an id when any of its tokens was
    # already coloured by an earlier group.
    # NOTE: the key prefixes are historical and read backwards: "src_<i>"
    # labels decoder positions and "trg_<i>" labels encoder positions.
    colordict = {}
    ncolors = 0
    for dec_positions, enc_positions in merged:
        labels = [f"src_{pos}" for pos in dec_positions]
        labels += [f"trg_{pos}" for pos in enc_positions]

        color_id = next(
            (colordict[label] for label in labels if label in colordict), None
        )
        # Compare against None explicitly: colour id 0 is valid but falsy.
        if color_id is None:
            color_id = ncolors
            ncolors += 1

        for label in labels:
            colordict.setdefault(label, color_id)

    # Swap the integer colour ids for hex strings; no palette is needed when
    # no group was formed.
    palette = generate_diverging_colors(ncolors, palette="Set2") if ncolors else []
    colordict = {label: palette[color_id] for label, color_id in colordict.items()}

    tgthtml = _render_colored_tokens(tokenizer, decoder_input_ids[0], "src", colordict)

    skipped = set()
    if skip_first_src:
        skipped.add(0)
    if skip_second_src:
        skipped.add(1)
    srchtml = _render_colored_tokens(
        tokenizer, encoder_input_ids[0], "trg", colordict, skipped
    )

    return srchtml, tgthtml
|
|