kevin1911 committed on
Commit
da6e24a
·
verified ·
1 Parent(s): f9589f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -53
app.py CHANGED
@@ -1,94 +1,197 @@
1
  import torch
2
  import gradio as gr
3
  import plotly.express as px
 
4
  from transformers import AutoModel, AutoTokenizer
5
 
6
  ########################################
7
# Load Transformer (DistilBERT) with attention
########################################
# Checkpoint name shared by the tokenizer and the model below.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Note: output_attentions=True to extract attention matrices
model = AutoModel.from_pretrained(model_name, output_attentions=True)
# Evaluation mode: the model is only run for inference in this script.
model.eval()
14
 
15
def visualize_attention(text, layer=5):
    """
    Render one DistilBERT layer's self-attention for `text` as a Plotly heatmap.

    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).

    Args:
        text: Sentence to analyze.
        layer: Index into `outputs.attentions`; coerced to int so a
            float-valued slider reading still works as a tuple index.

    Returns:
        The Plotly figure produced by `px.imshow`.
    """
    # Tuple indexing below requires a real int, not e.g. 5.0.
    layer = int(layer)

    with torch.no_grad():
        # truncation=True clips over-long text to the tokenizer's maximum
        # length instead of letting the forward pass fail on it.
        inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        # outputs.attentions: tuple of shape [num_layers] each => (batch=1, num_heads, seq_len, seq_len)
        all_attentions = outputs.attentions
        # DistilBERT has 6 layers => valid indices: 0..5
        attn_layer = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens like [CLS], [SEP])
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")

    # Add tooltip: shows row token, column token, and attention weight
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    return fig
 
54
 
55
# Short explanation text for the UI
description_text = """
## Understanding Transformer Self-Attention

- **Rows = "Query token"** (the token that is looking at other tokens)
- **Columns = "Key token"** (the token being looked at)
- Darker (or higher) color = stronger attention.

**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
Thus, **long-distance dependencies** are easier to capture: any token can directly
attend to any other token, regardless of distance in the sentence.
"""

########################################
# Gradio Interface
########################################
# NOTE(review): indentation was lost in the scraped diff; the nesting below
# (inputs inside the Row, plot and button beneath it) is inferred from the
# blank-line grouping -- confirm against app.py.
with gr.Blocks() as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )

    output_plot = gr.Plot(label="Attention Heatmap")

    # Clicking the button feeds both inputs to visualize_attention and shows
    # the returned figure in output_plot.
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

# Only start the server when executed as a script, so importing this module
# (tests, reload tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
 
1
  import torch
2
  import gradio as gr
3
  import plotly.express as px
4
+ import numpy as np
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  ########################################
8
# 1) Load DistilBERT with attention
########################################
# Checkpoint name shared by the tokenizer and the model below.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True is requested because the analysis code reads
# outputs.attentions from the forward pass.
model = AutoModel.from_pretrained(model_name, output_attentions=True)
# Evaluation mode: the model is only run for inference in this script.
model.eval()
14
 
15
+ ########################################
16
+ # 2) Generate attention analysis
17
+ ########################################
18
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """
    Analyze one DistilBERT layer's self-attention for 'text'.

    1. Tokenize 'text'.
    2. Forward pass DistilBERT (output_attentions=True).
    3. Extract attention from chosen layer (0..5).
    4. Average across heads => (seq_len, seq_len).
    5. Optionally create Plotly heatmap.
    6. Create text summary of top-k focuses for each token.
    7. Generate an "interpretation" to highlight interesting patterns.

    Args:
        text: Input text to tokenize and analyze.
        layer: Index into `outputs.attentions`; coerced to int.
        top_k: How many highest-weight tokens to list per row; coerced to int.
        show_heatmap: When false, no figure is built and None is returned
            in its place.

    Returns:
        (fig, combined_md): the Plotly figure (or None) and a Markdown string
        holding the top-k summary followed by the interpretation section.
    """
    # Slider readings may be float-valued; indexing and slicing need ints.
    layer = int(layer)
    top_k = int(top_k)

    with torch.no_grad():
        # truncation=True clips over-long text to the tokenizer's maximum
        # length instead of letting the forward pass fail on it.
        inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        all_attentions = outputs.attentions  # tuple: (#layers) each => (1, #heads, seq_len, seq_len)
        # DistilBERT has 6 layers => valid range: 0..5
        att = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)

    att_matrix = att[0].cpu().numpy()  # (seq_len, seq_len)
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # (Optional) Heatmap
    fig = None
    if show_heatmap:
        fig = px.imshow(
            att_matrix,
            x=tokens,
            y=tokens,
            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
            color_continuous_scale="Blues",
            title=f"DistilBERT Self-Attention (Layer {layer})"
        )
        fig.update_xaxes(side="top")
        fig.update_traces(
            hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
        )
        # The Figure itself is returned (not fig.to_dict()): the result is
        # wired to a gr.Plot output, which renders plot objects.

    # Top-K Summary for each row, collected in a list and joined once.
    summary_parts = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        # Indices of this row's weights, largest first, cut to top_k.
        top_indices = row_weights.argsort()[::-1][:top_k]

        summary_parts.append(f"**Token '{row_token}'** focuses on:\n")
        for j in top_indices:
            col_token = tokens[j]
            weight = row_weights[j]
            summary_parts.append(f" - `{col_token}` (weight={weight:.3f})\n")
        summary_parts.append("\n")
    summary_md = "".join(summary_parts)

    # Generate an additional "interpretation" to highlight patterns
    interpretation_md = interpret_attention(att_matrix, tokens)

    # Combine summaries
    combined_md = summary_md + "\n" + interpretation_md

    return fig, combined_md
81
+
82
+
83
+ ########################################
84
+ # 3) Interpretation function
85
+ ########################################
86
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """
    Provide a short bullet-list interpretation of the attention matrix.

    - Count how many tokens mostly attend to themselves (diagonal).
    - Find the global max attention weight (which row->col?), mention tokens involved.

    Args:
        att_matrix: 2-D array of weights; row i holds the weights token i
            assigns to every token.
        tokens: Token strings labelling the rows/columns of `att_matrix`.

    Returns:
        A Markdown string. With no tokens, the section header plus a single
        bullet saying there is nothing to interpret.
    """
    header = "## Additional Interpretation\n\n"

    seq_len = len(tokens)
    if seq_len == 0:
        # Nothing to index into; avoid the IndexError on tokens[0].
        return header + "- No tokens to interpret.\n"

    rows = np.asarray(att_matrix)[:seq_len]
    row_ids = np.arange(rows.shape[0])

    # Column of the strongest weight in each row (first one on ties).
    best_cols = rows.argmax(axis=1)
    diagonal_focus_count = int(np.count_nonzero(best_cols == row_ids))

    # Per-row peaks; argmax over them picks the first row holding the global
    # maximum, so no sentinel start value is needed.
    row_peaks = rows[row_ids, best_cols]
    top_row = int(row_peaks.argmax())
    top_col = int(best_cols[top_row])
    max_val = row_peaks[top_row]

    # 1) Diagonal focus stat
    diag_msg = f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."

    # 2) Global max
    token_i = tokens[top_row]
    token_j = tokens[top_col]
    global_msg = f"- The **highest single focus** in the matrix is **{max_val:.3f}**, from token '{token_i}' onto '{token_j}'."

    parts = [
        header,
        "Here are some overall patterns in the attention matrix that might help you:\n\n",
        f"{diag_msg}\n",
        f"{global_msg}\n",
        "\n- A strong diagonal means tokens often reference themselves.\n",
        "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n",
    ]
    return "".join(parts)
 
 
138
 
 
 
 
 
139
 
140
  ########################################
141
+ # 4) Gradio UI
142
  ########################################
143
# Markdown intro rendered at the top of the page via gr.Markdown: usage steps
# and how to read the heatmap and the summaries.
description_md = """
# DistilBERT Self-Attention with Extra Interpretation

**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.

**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.

Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""
161
+
162
def run_demo(text, layer, top_k, show_heatmap):
    """Gradio click handler: forward the UI values to analyze_attention.

    Returns the heatmap value and the Markdown summary, in the order of the
    click handler's `outputs` list.
    """
    fig_dict, summary_md = analyze_attention(text, layer, top_k, show_heatmap)
    return fig_dict, summary_md


# NOTE(review): indentation was lost in the scraped diff; the nesting below
# (four inputs inside the Row, button and outputs beneath it) is inferred --
# confirm against app.py.
with gr.Blocks() as demo:
    gr.Markdown(description_md)

    with gr.Row():
        text_in = gr.Textbox(
            label="Enter text",
            value="Transformers handle long-range context in parallel."
        )
        layer_in = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer"
        )
        topk_in = gr.Slider(
            minimum=1, maximum=6, step=1, value=3,
            label="Top-K Focus"
        )
        show_heatmap_check = gr.Checkbox(
            label="Show Heatmap?",
            value=True
        )
    run_btn = gr.Button("Analyze Attention")

    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    # Input order must match run_demo's parameters; output order must match
    # the (plot, markdown) pair it returns.
    run_btn.click(
        fn=run_demo,
        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
        outputs=[out_plot, out_summary]
    )

# Only start the server when executed as a script, so importing this module
# (tests, reload tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()