import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer

########################################
# Load Transformer (DistilBERT) with attention
########################################
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Note: output_attentions=True to extract attention matrices
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()


def visualize_attention(text, layer=5):
    """
    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(text, return_tensors="pt")
        outputs = model(**inputs)

    # outputs.attentions: tuple of length num_layers,
    # each element => (batch=1, num_heads, seq_len, seq_len)
    all_attentions = outputs.attentions

    # DistilBERT has 6 layers => valid indices: 0..5
    # Gradio sliders can deliver floats, so cast before indexing the tuple.
    layer = int(layer)
    attn_layer = all_attentions[layer].mean(dim=1)  # average across heads => shape: (1, seq_len, seq_len)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens like [CLS], [SEP])
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")

    # Add tooltip: shows row token, column token, and attention weight
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )

    return fig


# Short explanation text for the UI
description_text = """
## Understanding Transformer Self-Attention

- **Rows = "Query token"** (the token that is looking at other tokens)
- **Columns = "Key token"** (the token being looked at)
- Darker (or higher) color = stronger attention.

**Transformers** process all tokens in **parallel**, not step-by-step like RNNs.
Thus, **long-distance dependencies** are easier to capture: any token can directly
attend to any other token, regardless of distance in the sentence.
"""

########################################
# Gradio Interface
########################################
with gr.Blocks() as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )

    output_plot = gr.Plot(label="Attention Heatmap")
    visualize_button = gr.Button("Visualize Attention")

    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

demo.launch()