Update app.py
app.py
CHANGED
@@ -1,94 +1,197 @@
 import torch
 import gradio as gr
 import plotly.express as px
+import numpy as np
 from transformers import AutoModel, AutoTokenizer

 ########################################
-# Load
+# 1) Load DistilBERT with attention
 ########################################
 model_name = "distilbert-base-uncased"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-# Note: output_attentions=True to extract attention matrices
 model = AutoModel.from_pretrained(model_name, output_attentions=True)
 model.eval()

-
+########################################
+# 2) Generate attention analysis
+########################################
+def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
     """
-    1. Tokenize
-    2.
-    3.
-    4.
-    5.
+    1. Tokenize 'text'.
+    2. Forward pass DistilBERT (output_attentions=True).
+    3. Extract attention from chosen layer (0..5).
+    4. Average across heads => (seq_len, seq_len).
+    5. Optionally create Plotly heatmap (fig_dict).
+    6. Create text summary of top-k focuses for each token.
+    7. Generate an "interpretation" to highlight interesting patterns.
     """
+
     with torch.no_grad():
         inputs = tokenizer.encode_plus(text, return_tensors="pt")
         outputs = model(**inputs)
-
-
-    #
-    attn_layer = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)
-
-    # Convert to numpy for plotting
-    attn_matrix = attn_layer[0].cpu().numpy()
+    all_attentions = outputs.attentions  # tuple: (#layers) each => (1, #heads, seq_len, seq_len)
+    # DistilBERT has 6 layers => valid range: 0..5
+    att = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)

-
+    att_matrix = att[0].cpu().numpy()  # (seq_len, seq_len)
     input_ids = inputs["input_ids"][0]
     tokens = tokenizer.convert_ids_to_tokens(input_ids)
+    seq_len = len(tokens)

-    #
-
-
-
-
-
-
-
-
-
+    # (Optional) Heatmap
+    fig_dict = None
+    if show_heatmap:
+        fig = px.imshow(
+            att_matrix,
+            x=tokens,
+            y=tokens,
+            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
+            color_continuous_scale="Blues",
+            title=f"DistilBERT Self-Attention (Layer {layer})"
+        )
+        fig.update_xaxes(side="top")
+        fig.update_traces(
+            hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
+        )
+        fig_dict = fig.to_dict()
+
+    # Top-K Summary for each row
+    summary_md = "## Top-K Focus for Each Token\n"
+    summary_md += f"Showing the **top {top_k}** tokens each token focuses on.\n\n"
+    for i in range(seq_len):
+        row_token = tokens[i]
+        row_weights = att_matrix[i]
+        sorted_idx = row_weights.argsort()[::-1]
+        top_indices = sorted_idx[:top_k]
+
+        summary_md += f"**Token '{row_token}'** focuses on:\n"
+        for j in top_indices:
+            col_token = tokens[j]
+            weight = row_weights[j]
+            summary_md += f" - `{col_token}` (weight={weight:.3f})\n"
+        summary_md += "\n"
+
+    # Generate an additional "interpretation" to highlight patterns
+    interpretation_md = interpret_attention(att_matrix, tokens)
+
+    # Combine summaries
+    combined_md = summary_md + "\n" + interpretation_md
+
+    return fig_dict, combined_md
+
+
+########################################
+# 3) Interpretation function
+########################################
+def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
+    """
+    Provide a short bullet-list interpretation of the attention matrix:
+    - Count how many tokens mostly attend to themselves (diagonal).
+    - Find the global max attention weight (which row->col?), mention tokens involved.
+    - Possibly mention if we see something interesting about distribution.
+    """
+
+    seq_len = len(tokens)
+    diagonal_focus_count = 0
+    # We'll track the max weight overall
+    max_val = -1.0
+    max_rc = (0, 0)
+
+    # For each row, check if diagonal is the top focus
+    for i in range(seq_len):
+        row = att_matrix[i]
+        best_j = row.argmax()
+        if best_j == i:
+            diagonal_focus_count += 1
+        # Check global max
+        if row[best_j] > max_val:
+            max_val = row[best_j]
+            max_rc = (i, best_j)

-    #
-
-
+    # Summaries
+    # 1) Diagonal focus stat
+    diag_msg = f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."
+
+    # 2) Global max
+    i, j = max_rc
+    token_i = tokens[i]
+    token_j = tokens[j]
+    global_msg = f"- The **highest single focus** in the matrix is **{max_val:.3f}**, from token '{token_i}' onto '{token_j}'."
+
+    # 3) Possibly some quick ratio
+    # For each row, sum of row vs. sum of diagonal
+    # We'll keep it simpler for now
+
+    interpretation = "## Additional Interpretation\n\n"
+    interpretation += (
+        "Here are some overall patterns in the attention matrix that might help you:\n\n"
     )
-
+    interpretation += f"{diag_msg}\n"
+    interpretation += f"{global_msg}\n"

-
-
-
+    interpretation += "\n- A strong diagonal means tokens often reference themselves.\n"
+    interpretation += (
+        "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n"
+    )

-
-- **Columns = "Key token"** (the token being looked at)
-- Darker (or higher) color = stronger attention.
+    return interpretation

-**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
-Thus, **long-distance dependencies** are easier to capture: any token can directly
-attend to any other token, regardless of distance in the sentence.
-"""

 ########################################
-# Gradio
+# 4) Gradio UI
 ########################################
+description_md = """
+# DistilBERT Self-Attention with Extra Interpretation
+
+**Instructions:**
+1. Type your text into the box.
+2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
+3. Decide how many top tokens you want listed for each token (Top-K).
+4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.
+
+**Reading the Heatmap**:
+- **Rows** = tokens doing the looking (focus).
+- **Columns** = tokens being looked at.
+- **Color intensity** = how strongly the row token focuses on the column token.
+
+Below the heatmap, you'll see:
+- A **Top-K focus** summary for each token.
+- An **interpretation** bullet list that highlights interesting overall patterns.
+"""
+
+def run_demo(text, layer, top_k, show_heatmap):
+    fig_dict, summary_md = analyze_attention(text, layer, top_k, show_heatmap)
+    return fig_dict, summary_md
+
 with gr.Blocks() as demo:
-    gr.Markdown(
-    gr.Markdown(description_text)
+    gr.Markdown(description_md)

     with gr.Row():
-
-            label="Enter
+        text_in = gr.Textbox(
+            label="Enter text",
             value="Transformers handle long-range context in parallel."
         )
-
+        layer_in = gr.Slider(
             minimum=0, maximum=5, step=1, value=5,
-            label="DistilBERT Layer
+            label="DistilBERT Layer"
         )
+        topk_in = gr.Slider(
+            minimum=1, maximum=6, step=1, value=3,
+            label="Top-K Focus"
+        )
+        show_heatmap_check = gr.Checkbox(
+            label="Show Heatmap?",
+            value=True
+        )
+    run_btn = gr.Button("Analyze Attention")

-
+    out_plot = gr.Plot(label="Attention Heatmap")
+    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

-
-
-
-
-        outputs=output_plot
+    run_btn.click(
+        fn=run_demo,
+        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
+        outputs=[out_plot, out_summary]
     )

 demo.launch()
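
The new analysis code can also be exercised without the Gradio UI. The snippet below is an illustrative sketch, not part of this commit: it assumes the functions from the diff above are available in the current session, since importing app.py directly would also run demo.launch() and start the server.

# Illustrative sketch, not part of app.py: run the analysis headlessly.
# Assumes analyze_attention() as defined in the diff above.
sample = "Transformers handle long-range context in parallel."

# show_heatmap=False skips the Plotly figure, so fig_dict comes back as None.
fig_dict, report_md = analyze_attention(sample, layer=3, top_k=2, show_heatmap=False)
print(report_md)  # Top-K focus summary plus the interpretation bullets

One related caveat: Gradio sliders are documented to pass their values to the callback as floats, so layer and top_k may arrive as 5.0 and 3.0 rather than 5 and 3. Because analyze_attention uses them directly for tuple indexing and slicing, casting inside run_demo, e.g. analyze_attention(text, int(layer), int(top_k), show_heatmap), is a cheap guard.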