import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer


# Load DistilBERT once at startup; output_attentions=True makes every
# forward pass also return the per-layer attention weights.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()  # inference mode: disables dropout
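
# This demo assumes CPU inference. With a GPU available, one could move the
# model over (model.to("cuda")) and send each tokenized input to the same
# device before the forward pass.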


def visualize_attention(text, layer=5):
    """
    1. Tokenize the input text.
    2. Run a DistilBERT forward pass to get the attention matrices.
    3. Pick a layer (0..5) and average across the attention heads.
    4. Generate a Plotly heatmap of shape (seq_len x seq_len).
    5. Label the axes with tokens (Query vs. Key).
    """
    layer = int(layer)  # Gradio sliders pass floats; tuple indexing needs an int
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions  # tuple of 6 tensors, one per layer

    # Each entry has shape (batch, heads, seq_len, seq_len); average across
    # the heads to get one (seq_len x seq_len) matrix for the chosen layer.
    attn_layer = all_attentions[layer].mean(dim=1)
    attn_matrix = attn_layer[0].cpu().numpy()

    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Index-prefix each token so the axis labels are unique: Plotly merges
    # duplicate categorical labels, which would collapse repeated tokens
    # into a single row/column of the heatmap.
    axis_labels = [f"{i}: {tok}" for i, tok in enumerate(tokens)]

    fig = px.imshow(
        attn_matrix,
        x=axis_labels,
        y=axis_labels,
        labels={"x": "Key (Being Attended To)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")  # show Key labels along the top edge
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}<extra></extra>"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))

    return fig
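
# Sanity-check sketch (run separately, not part of the app) of the shapes the
# function above relies on: DistilBERT returns one attention tensor per layer,
# each shaped (batch, heads, seq_len, seq_len), and every Query row is a
# softmax distribution over the Keys, so it sums to 1.
#
#     enc = tokenizer("Attention weights sum to one.", return_tensors="pt")
#     with torch.no_grad():
#         atts = model(**enc).attentions
#     assert len(atts) == 6          # distilbert-base-uncased has 6 layers
#     assert atts[0].shape[1] == 12  # ...and 12 heads per layer
#     row_sums = atts[0].sum(dim=-1)
#     assert torch.allclose(row_sums, torch.ones_like(row_sums))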


def interpret_token_attention(text, token_index=0, layer=5):
    """
    Explains which tokens a chosen token (acting as the Query) attends to,
    highlighting the top three tokens it focuses on.
    """
    layer = int(layer)              # Gradio sliders pass floats
    token_index = int(token_index)  # gr.Number passes floats by default
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
        attn_layer = all_attentions[layer].mean(dim=1)

    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    if token_index < 0 or token_index >= len(tokens):
        return f"Invalid token index. Please choose an index between 0 and {len(tokens) - 1}."

    # Attention row for the chosen Query token: how strongly it attends to
    # every Key token in the sequence.
    query_attn = attn_layer[0, token_index, :].cpu().numpy()

    # Pick the top three Key tokens by attention weight.
    sorted_indices = query_attn.argsort()[::-1]
    top_indices = sorted_indices[:3]
    top_tokens = [tokens[i] for i in top_indices]
    top_weights = [query_attn[i] for i in top_indices]

    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = What this token is looking for\n"
        "- **Key** = What another token has to offer\n"
        "- **Value** = The actual information from that token\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )

    for t, w in zip(top_tokens, top_weights):
        explanation += f"- **{t}** with attention weight ~{w:.3f}\n"

    explanation += (
        "\nA higher attention weight indicates that this Query token is 'looking at' or "
        "focusing on that Key token more strongly, likely because it finds the Key token "
        "relevant to its meaning or context."
    )

    return explanation
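
# Note: special tokens such as [CLS] and [SEP] often receive large attention
# weights and will frequently appear among the top tokens above; that is
# expected DistilBERT behavior, not a bug in the ranking.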


description_text = """
## Understanding Transformer Self-Attention

- **Rows = Query token** (the token doing the 'looking').
- **Columns = Key token** (the token being 'looked at').
- Darker color = stronger attention weight.

**Transformers** process all tokens in **parallel**, allowing any token to attend to any other token in the sentence.
This makes it easier for the model to capture long-distance relationships.
"""


with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )
    output_plot = gr.Plot(label="Attention Heatmap")

    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

    token_index = gr.Number(
        label="Choose a token index to interpret (0-based)",
        value=0,
        precision=0  # return an integer rather than a float
    )
    interpretation_output = gr.Markdown(label="Interpretation")

    interpret_button = gr.Button("Explain This Token's Attention")
    interpret_button.click(
        fn=interpret_token_attention,
        inputs=[text_input, token_index, layer_slider],
        outputs=interpretation_output
    )


if __name__ == "__main__":
    demo.launch()