import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer

# Load DistilBERT and request attention weights from every layer.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()
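
# Quick structural check (illustrative; safe to delete): with output_attentions=True
# the forward pass returns one attention tensor per layer, each shaped
# (batch, num_heads, seq_len, seq_len) -- distilbert-base-uncased has 6 layers of 12 heads.
_probe = tokenizer("a quick check", return_tensors="pt")
with torch.no_grad():
    _probe_attentions = model(**_probe).attentions
print(f"{len(_probe_attentions)} layers, attention shape {tuple(_probe_attentions[0].shape)}")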


def visualize_attention(text, layer=5):
    """
    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).
    """
    layer = int(layer)  # Gradio sliders may deliver floats; tuple indexing needs an int.

    with torch.no_grad():
        inputs = tokenizer.encode_plus(text, return_tensors="pt")
        outputs = model(**inputs)

    # Tuple with one tensor per layer, each (batch, num_heads, seq_len, seq_len).
    all_attentions = outputs.attentions

    # Average over the heads of the chosen layer -> (batch, seq_len, seq_len).
    attn_layer = all_attentions[layer].mean(dim=1)
    attn_matrix = attn_layer[0].cpu().numpy()

    # Recover the token strings (including [CLS] and [SEP]) to label both axes.
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    return fig
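

# The matrices in `outputs.attentions` are the softmax weights of scaled
# dot-product attention. A minimal self-contained sketch of that formula for a
# single head (illustrative only; toy shapes, not used by the demo):
def _toy_attention_weights(q, k):
    """Return softmax(Q @ K^T / sqrt(d_k)), i.e. one head's attention weights."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)  # (seq_len, seq_len)
    return torch.softmax(scores, dim=-1)             # each row sums to 1

# e.g. _toy_attention_weights(torch.randn(4, 8), torch.randn(4, 8)) gives a 4x4 matrix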


description_text = """
## Understanding Transformer Self-Attention

- **Rows = "Query token"** (the token that is looking at other tokens)
- **Columns = "Key token"** (the token being looked at)
- Darker (or higher) color = stronger attention.

**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
Thus, **long-distance dependencies** are easier to capture: any token can directly
attend to any other token, regardless of distance in the sentence.
"""


with gr.Blocks() as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )

    output_plot = gr.Plot(label="Attention Heatmap")

    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

demo.launch()