value-zeroing / app.py
Martijn van Beers
Initial implementation
9f74b46
import re
import pandas
import seaborn
import gradio
import pathlib
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy
from sklearn.metrics.pairwise import cosine_distances
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForSequenceClassification, AutoModelForMaskedLM
)
## Rollout Helper Function
def compute_joint_attention(att_mat, res=True):
    """Roll per-layer score matrices out into cumulative (joint) scores.

    Entry ``i`` of the result is the matrix product
    ``att_mat[i] @ att_mat[i-1] @ ... @ att_mat[0]``.

    Args:
        att_mat: numpy array indexed as ``att_mat[layer]`` giving a square
            matrix per layer (presumably ``(layers, seq, seq)`` — the identity
            below is sized from ``att_mat.shape[1]``).
        res: if true, add an identity matrix (residual connection) to every
            layer and re-normalise each row to sum to 1 before rolling out.

    Returns:
        A float array with the same shape as ``att_mat`` holding the
        cumulative products; entry 0 is the (possibly residual-adjusted)
        first layer.
    """
    if res:
        identity = numpy.eye(att_mat.shape[1])[None, ...]
        with_residual = att_mat + identity
        # Rebind instead of modifying in place, so the caller's array is untouched.
        att_mat = with_residual / with_residual.sum(axis=-1, keepdims=True)

    rolled = numpy.zeros(att_mat.shape)
    rolled[0] = att_mat[0]
    for layer in range(1, len(rolled)):
        rolled[layer] = att_mat[layer] @ rolled[layer - 1]
    return rolled
def create_plot(all_tokens, score_data):
    """Draw one token-by-token heatmap per layer and return the figure.

    Args:
        all_tokens: labels used for both the rows and the columns of every
            heatmap; their count must match the size of each matrix in
            ``score_data``.
        score_data: sequence of per-layer square score matrices;
            ``score_data[layer]`` is plotted in panel ``layer``.

    Returns:
        A ``matplotlib.figure.Figure`` with the heatmaps laid out two per
        row (12 layers give the original 6x2 grid at 8x24 inches).
    """
    # Build the Figure directly rather than via pyplot: pyplot keeps every
    # figure it creates referenced until plt.close(), so a long-running server
    # would accumulate two figures per request, and pyplot's "current figure"
    # (used by plt.subplots_adjust) is global state shared by request threads.
    from matplotlib.figure import Figure

    n_layers = len(score_data)
    n_cols = 2
    n_rows = max(1, (n_layers + n_cols - 1) // n_cols)
    fig = Figure(figsize=(8, 4 * n_rows))
    # squeeze=False keeps axs 2-D even when there is only one row.
    axs = fig.subplots(n_rows, n_cols, squeeze=False)
    fig.subplots_adjust(top=0.98, bottom=0.05, hspace=0.5, wspace=0.5)
    for layer in range(n_layers):
        ax = axs[layer // n_cols, layer % n_cols]
        seaborn.heatmap(
            ax=ax,
            data=pandas.DataFrame(score_data[layer], index=all_tokens, columns=all_tokens),
            cmap="Blues",
            annot=False,
            cbar=False,
        )
        ax.set_title(f"Layer: {layer+1}")
    # Hide the unused trailing panel when the layer count is odd.
    for ax in axs.flat[n_layers:]:
        ax.set_visible(False)
    return fig
# Non-interactive Agg backend: figures are only rendered to images for the UI.
matplotlib.use('agg')

# Metric name -> pairwise distance function used to compare layer outputs.
DISTANCE_FUNC = dict(cosine=cosine_distances)

# Selectable model key -> Hugging Face checkpoint. run() also uses the key as
# the attribute name of the base encoder (getattr(model, MODEL_NAME)).
MODEL_PATH = dict(
    bert='bert-base-uncased',
    roberta='roberta-base',
)

# Model loaded at start-up; change to 'roberta' to start with RoBERTa instead.
MODEL_NAME = 'bert'
METRIC = 'cosine'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_PATH[MODEL_NAME], config=config
).to(device)
def run(mname, sent):
    """Compute value-zeroing scores for *sent* and plot them per layer.

    For every encoder layer ``l`` and token ``t``, layer ``l`` is re-run on its
    original input with ``zero_value_index=t``. Score ``(l, i, t)`` is the
    ``DISTANCE_FUNC[METRIC]`` distance between token ``i``'s original output
    of layer ``l`` and its output in that re-run. Each row ``(l, i, :)`` is
    then normalised to sum to 1.

    NOTE(review): ``zero_value_index`` is not an argument of the stock
    Hugging Face encoder layers; this assumes a patched ``transformers``
    build that implements value zeroing — confirm for the deployment.

    Args:
        mname: key into ``MODEL_PATH``. If it differs from the currently loaded
            ``MODEL_NAME``, the global config, tokenizer and model are reloaded.
        sent: input sentence. Any ``MASK`` with one character on each side
            (e.g. ``[MASK]`` or ``<MASK>``) is replaced by the tokenizer's
            mask token.

    Returns:
        ``(rollout_fig, value_fig)``: figures of the rolled-out scores and of
        the raw per-layer value-zeroing scores.
    """
    global MODEL_NAME, config, model, tokenizer
    if mname != MODEL_NAME:
        # Load everything before touching the globals, so a failed download
        # cannot leave MODEL_NAME naming a model that was never loaded.
        new_config = AutoConfig.from_pretrained(MODEL_PATH[mname])
        new_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[mname])
        new_model = AutoModelForMaskedLM.from_pretrained(
            MODEL_PATH[mname], config=new_config
        ).to(device)
        MODEL_NAME, config, tokenizer, model = mname, new_config, new_tokenizer, new_model

    sent = re.sub(r".MASK.", tokenizer.mask_token, sent)
    inputs = tokenizer(sent, return_token_type_ids=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    ## Compute: layerwise value zeroing
    # Reference forward pass, keeping the hidden states around every layer.
    with torch.no_grad():
        outputs = model(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        token_type_ids=inputs['token_type_ids'],
                        output_hidden_states=True, output_attentions=False)
    # Stack per-layer hidden states and drop the batch dim (a single sentence).
    org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)

    input_shape = inputs['input_ids'].size()
    seq_length = input_shape[1]
    base_model = getattr(model, MODEL_NAME)
    # The extended mask depends on neither the layer nor the zeroed token,
    # so build it once instead of inside the inner loop.
    extended_attention_mask = base_model.get_extended_attention_mask(
        inputs['attention_mask'], input_shape, device
    )
    distance = DISTANCE_FUNC[METRIC]

    score_matrix = numpy.zeros((config.num_hidden_layers, seq_length, seq_length))
    for l, layer_module in enumerate(base_model.encoder.layer):
        layer_input = org_hidden_states[l].unsqueeze(0)  # previous layer's original output
        original_output = org_hidden_states[l + 1].detach().cpu().numpy()
        for t in range(seq_length):
            with torch.no_grad():
                layer_outputs = layer_module(layer_input,
                                             attention_mask=extended_attention_mask,
                                             output_attentions=False,
                                             zero_value_index=t)
            # Index the batch dim away; unlike .squeeze() this keeps a
            # length-1 sequence 2-D, as the pairwise distance needs.
            zeroed_output = layer_outputs[0][0].detach().cpu().numpy()
            # Pairwise distances; the diagonal compares each token with itself.
            score_matrix[l, :, t] = distance(zeroed_output, original_output).diagonal()

    valuezeroing_scores = score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True)
    rollout_valuezeroing_scores = compute_joint_attention(valuezeroing_scores, res=False)

    # Plot. input_ids holds one sequence; label the axes with its tokens as a
    # flat list (a nested list becomes a MultiIndex and stray axis labels).
    all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
    rollout_fig = create_plot(all_tokens, rollout_valuezeroing_scores)
    value_fig = create_plot(all_tokens, valuezeroing_scores)
    return rollout_fig, value_fig
# Example sentences offered below the input box (one per CSV row).
examples = pandas.read_csv("examples.csv").to_numpy().tolist()

with gradio.Blocks(
    title="Differences with/without zero-valuing",
    css=".output-image > img {height: 2000px !important; max-height: none !important;} ",
) as iface:
    # read_text is passed uncalled; gradio accepts a callable value and
    # evaluates it when the page loads.
    gradio.Markdown(pathlib.Path("description.md").read_text)
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            # Offer exactly the models run() can load, defaulting to the one
            # loaded at start-up, instead of repeating them as literals.
            model_choice = gradio.Dropdown(
                choices=list(MODEL_PATH), value=MODEL_NAME, label="Model"
            )
            but = gradio.Button("Submit")
    gradio.Examples(examples, [sent])
    with gradio.Row(equal_height=True):
        with gradio.Column():
            gradio.Markdown("### With Rollout")
            rollout_result = gradio.Plot()
        with gradio.Column():
            gradio.Markdown("### Without Rollout")
            value_result = gradio.Plot()
    with gradio.Accordion("Some more details"):
        gradio.Markdown(pathlib.Path("notice.md").read_text)

    # run() returns (rollout_fig, value_fig), matching the outputs order.
    but.click(run,
              inputs=[model_choice, sent],
              outputs=[rollout_result, value_result])

iface.launch()