File size: 3,339 Bytes
0b23237
1f47b32
0b23237
 
1f47b32
0b23237
 
 
 
 
 
 
 
1f47b32
0b23237
1f47b32
0b23237
 
 
 
 
1f47b32
0b23237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f47b32
0b23237
1f47b32
0b23237
 
 
1f47b32
0b23237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f47b32
 
0b23237
 
 
 
 
 
 
 
1f47b32
0b23237
1f47b32
0b23237
 
 
 
 
1f47b32
 
 
0b23237
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer

########################################
# Model setup: DistilBERT configured to return attention weights
########################################
model_name = "distilbert-base-uncased"

# output_attentions=True makes every forward pass also return the per-layer
# attention matrices. Module.eval() returns the module itself, so the call
# can be chained onto the load; it switches the model to inference mode.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

def visualize_attention(text, layer=5):
    """
    Render one DistilBERT attention layer as an interactive heatmap.

    1. Tokenize input text (truncated to the tokenizer's maximum length).
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).

    Args:
        text: Sentence to analyse.
        layer: Index of the layer to display. It is converted with ``int()``,
            so a float such as ``5.0`` (e.g. coming from a UI slider) is
            accepted; a fractional value is truncated toward zero.

    Returns:
        The Plotly figure created by ``px.imshow``.

    Raises:
        IndexError: If ``layer`` does not index one of the attention layers
            returned by the model.
    """
    # A float cannot index the attentions tuple and would also show up as
    # "Layer 5.0" in the title, so normalise to int up front.
    layer = int(layer)

    with torch.no_grad():
        # truncation=True keeps overly long input within the model's position
        # limit instead of letting the forward pass fail.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        # outputs.attentions: tuple of length num_layers, each entry shaped
        # (batch=1, num_heads, seq_len, seq_len)
        all_attentions = outputs.attentions
        num_layers = len(all_attentions)
        if not -num_layers <= layer < num_layers:
            raise IndexError(
                f"layer {layer} is out of range for a model with "
                f"{num_layers} attention layers"
            )
        # Average across heads => shape: (1, seq_len, seq_len)
        attn_layer = all_attentions[layer].mean(dim=1)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens like [CLS], [SEP])
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Plotly treats string axis values as categories, and identical categories
    # share one axis position. A sentence with a repeated token would therefore
    # have rows/columns drawn on top of each other. Give every repeat a
    # " (n)" suffix so each position gets its own row and column.
    occurrences = {}
    axis_labels = []
    for token in tokens:
        count = occurrences.get(token, 0) + 1
        occurrences[token] = count
        axis_labels.append(token if count == 1 else f"{token} ({count})")

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=axis_labels,
        y=axis_labels,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")

    # Add tooltip: shows row token, column token, and attention weight
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    return fig

# Markdown explanation rendered in the UI via gr.Markdown: tells the user how
# to read the heatmap (rows are query tokens, columns are key tokens).
description_text = """
## Understanding Transformer Self-Attention

- **Rows = "Query token"** (the token that is looking at other tokens)
- **Columns = "Key token"** (the token being looked at)
- Darker (or higher) color = stronger attention.

**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
Thus, **long-distance dependencies** are easier to capture: any token can directly
attend to any other token, regardless of distance in the sentence.
"""

########################################
# Gradio UI: sentence box + layer slider -> attention heatmap
########################################
with gr.Blocks() as demo:
    # Page heading followed by the explanatory markdown.
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    # Both inputs share a single row.
    with gr.Row():
        text_input = gr.Textbox(
            value="Transformers handle long-range context in parallel.",
            label="Enter a sentence",
        )
        layer_slider = gr.Slider(
            label="DistilBERT Layer (0=lowest, 5=highest)",
            minimum=0,
            maximum=5,
            value=5,
            step=1,
        )

    output_plot = gr.Plot(label="Attention Heatmap")

    # Clicking the button recomputes the heatmap from the current inputs.
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot,
    )

demo.launch()