import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer

# Page configuration: browser-tab title and icon, wide layout, and the
# sidebar shown expanded on first load.
_PAGE_CONFIG = dict(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS styling (dark theme with green accents), injected as raw HTML.
# NOTE(review): `.reportview-container` and `.sidebar .sidebar-content` are
# class names from older Streamlit releases; the `data-testid` selectors next
# to them target the same containers in newer releases. Both are kept so the
# theme applies on either — confirm against the deployed Streamlit version.
_CUSTOM_CSS = """
<style>
    .reportview-container,
    [data-testid="stAppViewContainer"] {
        background: linear-gradient(45deg, #1a1a1a, #4a4a4a);
    }
    .sidebar .sidebar-content,
    [data-testid="stSidebar"] {
        background: #2c2c2c !important;
    }
    h1, h2, h3, h4, h5, h6 {
        color: #00ff00 !important;
    }
    .stMetric {
        background-color: #333333;
        border-radius: 10px;
        padding: 15px;
    }
    .architecture {
        font-family: monospace;
        color: #00ff00;
        white-space: pre-wrap;
        background-color: #1a1a1a;
        padding: 20px;
        border-radius: 10px;
        border: 1px solid #00ff00;
    }
    .token-table {
        margin-top: 20px;
        border: 1px solid #00ff00;
        border-radius: 5px;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Model database: UI display name -> Hugging Face model id ("model_name") plus
# the architecture facts shown in the app.
#   type   - layout used by visualize_architecture ("Encoder", "Decoder",
#            "Seq2Seq", "AutoRegressive")
#   layers - number of layers per stack (Seq2Seq diagrams repeat this count
#            for both the encoder and the decoder stack)
#   heads  - attention heads per layer
#   params - parameter count in millions
MODELS = {
    "BERT": {"model_name": "bert-base-uncased", "type": "Encoder", "layers": 12, "heads": 12, "params": 109.48},
    "GPT-2": {"model_name": "gpt2", "type": "Decoder", "layers": 12, "heads": 12, "params": 117},
    "T5-Small": {"model_name": "t5-small", "type": "Seq2Seq", "layers": 6, "heads": 8, "params": 60},
    "RoBERTa": {"model_name": "roberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 125},
    "DistilBERT": {"model_name": "distilbert-base-uncased", "type": "Encoder", "layers": 6, "heads": 12, "params": 66},
    "ALBERT": {"model_name": "albert-base-v2", "type": "Encoder", "layers": 12, "heads": 12, "params": 11.8},
    # The "small" ELECTRA discriminator uses 4 attention heads (hidden size 256).
    "ELECTRA": {"model_name": "google/electra-small-discriminator", "type": "Encoder", "layers": 12, "heads": 4, "params": 13.5},
    "XLNet": {"model_name": "xlnet-base-cased", "type": "AutoRegressive", "layers": 12, "heads": 12, "params": 110},
    # bart-base has 12 heads per attention layer; 16 is the bart-large value.
    "BART": {"model_name": "facebook/bart-base", "type": "Seq2Seq", "layers": 6, "heads": 12, "params": 139},
    "DeBERTa": {"model_name": "microsoft/deberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 139},
}

def get_model_config(model_name):
    """Load the Hugging Face config for one of the models listed in ``MODELS``.

    Args:
        model_name: A display-name key of ``MODELS`` (e.g. ``"BERT"``), not a
            Hugging Face model id.

    Returns:
        The object returned by ``AutoConfig.from_pretrained`` for the mapped id.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    hub_id = MODELS[model_name]["model_name"]
    return AutoConfig.from_pretrained(hub_id)

def plot_model_comparison(selected_model):
    """Render a bar chart of parameter counts (millions) for every model in ``MODELS``.

    The bar belonging to ``selected_model`` is highlighted in green.

    Args:
        selected_model: A key of ``MODELS``.

    Raises:
        ValueError: If ``selected_model`` is not a key of ``MODELS``.
    """
    model_names = list(MODELS.keys())
    params = [MODELS[name]["params"] for name in model_names]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars = ax.bar(model_names, params)
        bars[model_names.index(selected_model)].set_color('#00ff00')

        ax.set_ylabel('Parameters (Millions)', color='white')
        ax.set_title('Model Size Comparison', color='white')
        ax.tick_params(axis='x', rotation=45, colors='white')
        ax.tick_params(axis='y', colors='white')
        ax.set_facecolor('#2c2c2c')
        fig.patch.set_facecolor('#2c2c2c')

        st.pyplot(fig)
    finally:
        # Figures created via plt.subplots() stay registered with pyplot until
        # closed; without this every rerun of the app would leak one.
        plt.close(fig)

def _single_stack_lines(layer_label, attention_label, layers, heads):
    """Return diagram lines for embedding -> ``layers`` identical layers -> output."""
    lines = ["[Embedding Layer]"]
    for i in range(layers):
        lines.extend([
            f"{layer_label} Layer {i+1}",
            f"├─ {attention_label}",
            f"│  └─ {heads} Heads",
            "├─ Layer Normalization",
            "└─ Feed Forward Network",
            "│",
            "▼"
        ])
    lines.append("[Output]")
    return lines


def visualize_architecture(model_info):
    """Build a plain-text (box-drawing) diagram of a model's layer stack.

    Args:
        model_info: An entry shaped like the values of ``MODELS``. Reads
            ``"type"`` and ``"heads"``, and ``"layers"`` if present
            (default 12).

    Returns:
        The diagram as one newline-joined string. ``"Encoder"``,
        ``"Decoder"`` and ``"AutoRegressive"`` types get a single layer stack,
        ``"Seq2Seq"`` gets an encoder stack followed by a decoder stack, and
        any other type falls back to a generic single stack.

    Raises:
        KeyError: If ``model_info`` lacks ``"type"`` or ``"heads"``.
    """
    model_type = model_info["type"]
    layers = model_info.get("layers", 12)
    heads = model_info["heads"]

    architecture = ["Input", "│", "▼"]

    # (layer label, attention label) for single-stack layouts.
    single_stack_layouts = {
        "Encoder": ("Encoder", "Multi-Head Attention"),
        "Decoder": ("Decoder", "Masked Multi-Head Attention"),
        # Permutation-LM style autoregressive models (e.g. XLNet), which use
        # relative positional attention.
        "AutoRegressive": ("Transformer", "Relative Multi-Head Attention"),
    }

    if model_type == "Seq2Seq":
        architecture.append("Encoder Stack")
        for i in range(layers):
            architecture.extend([
                f"Encoder Layer {i+1}",
                "├─ Self-Attention",
                "└─ Feed Forward Network",
                "│",
                "▼"
            ])
        architecture.append("→→→ [Context] →→→")
        architecture.append("Decoder Stack")
        for i in range(layers):
            architecture.extend([
                f"Decoder Layer {i+1}",
                "├─ Masked Self-Attention",
                "├─ Encoder-Decoder Attention",
                "└─ Feed Forward Network",
                "│",
                "▼"
            ])
        architecture.append("[Output]")
    else:
        # Unknown types still get a generic stack instead of an empty diagram.
        layer_label, attention_label = single_stack_layouts.get(
            model_type, ("Transformer", "Multi-Head Attention")
        )
        architecture.extend(
            _single_stack_lines(layer_label, attention_label, layers, heads)
        )

    return "\n".join(architecture)

def visualize_attention_patterns():
    """Render an illustrative, randomly generated 5x5 attention heatmap.

    The matrix does not come from any model. Random scores are passed through
    a row-wise softmax so that, like real attention weights, every entry is
    non-negative and each query row sums to 1.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        scores = torch.randn(5, 5)
        weights = torch.softmax(scores, dim=-1)  # normalize over key positions
        image = ax.imshow(weights.numpy(), cmap='viridis')
        ax.set_title('Attention Patterns Example', color='white')
        ax.set_xlabel('Key position', color='white')
        ax.set_ylabel('Query position', color='white')
        # White ticks so they stay readable on the dark background, matching
        # plot_model_comparison.
        ax.tick_params(colors='white')
        colorbar = fig.colorbar(image, ax=ax)
        colorbar.ax.tick_params(colors='white')
        ax.set_facecolor('#2c2c2c')
        fig.patch.set_facecolor('#2c2c2c')
        st.pyplot(fig)
    finally:
        # Release the pyplot-managed figure so reruns don't accumulate them.
        plt.close(fig)

def get_hardware_recommendation(params):
    """Return a hardware-tier suggestion for a model size.

    Args:
        params: Parameter count in millions (as stored in ``MODELS``).

    Returns:
        The entry-level tier below 100, the mid-range tier from 100 up to (but
        not including) 200, and the high-end tier otherwise.
    """
    tiers = (
        (100, "CPU or Entry-level GPU (e.g., GTX 1060)"),
        (200, "Mid-range GPU (e.g., RTX 2080, RTX 3060)"),
    )
    for upper_bound, label in tiers:
        if params < upper_bound:
            return label
    return "High-end GPU (e.g., RTX 3090, A100) or TPU"

@st.cache_resource
def _load_tokenizer(model_id):
    """Return the tokenizer for Hugging Face ``model_id``, cached across reruns.

    Streamlit re-executes the script on every widget interaction (e.g. each
    edit of the tokenization text box); caching avoids reloading the
    tokenizer every time.
    """
    return AutoTokenizer.from_pretrained(model_id)


def main():
    """Render the Transformer Visualizer page for the model chosen in the sidebar."""
    st.title("🧠 Transformer Model Visualizer")

    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    model_info = MODELS[selected_model]
    tokenizer = _load_tokenizer(model_info["model_name"])

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Model Type", model_info["type"])
    with col2:
        st.metric("Layers", model_info.get("layers", "N/A"))
    with col3:
        st.metric("Attention Heads", model_info["heads"])
    with col4:
        st.metric("Parameters", f"{model_info['params']}M")

    tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
        "Model Structure", "Comparison", "Model Attention",
        "Tokenization", "Hardware", "Memory"
    ])

    with tab1:
        st.subheader("Architecture Diagram")
        architecture = visualize_architecture(model_info)
        st.markdown(f"<div class='architecture'>{architecture}</div>", unsafe_allow_html=True)

        st.markdown("""
        **Legend:**
        - **Multi-Head Attention**: Self-attention mechanism with multiple parallel heads
        - **Layer Normalization**: Normalization operation between layers
        - **Feed Forward Network**: Position-wise fully connected network
        - **Masked Attention**: Attention with future token masking
        """)

    with tab2:
        st.subheader("Model Size Comparison")
        plot_model_comparison(selected_model)

    with tab3:
        st.subheader("Model-specific Visualizations")
        visualize_attention_patterns()

    with tab4:
        st.subheader("📝 Tokenization Visualization")
        input_text = st.text_input("Enter Text:", "Hello, how are you?")
        tokens = tokenizer.tokenize(input_text)
        # encode() returns the full model input, which may include special
        # tokens that tokenize() does not show.
        encoded_ids = tokenizer.encode(input_text)

        col1, col2 = st.columns(2)
        with col1:
            st.markdown("**Tokenized Output**")
            st.write(tokens)
        with col2:
            st.markdown("**Token IDs**")
            st.write(encoded_ids)

        st.markdown("**Token-ID Mapping**")
        token_data = pd.DataFrame({
            "Token": tokens,
            # Look up each displayed token's own ID instead of slicing
            # encoded_ids: special-token placement differs per tokenizer
            # (T5 adds only a trailing EOS, XLNet appends its specials at the
            # end), so a fixed [1:-1] slice misaligns or mismatches lengths.
            "ID": tokenizer.convert_tokens_to_ids(tokens),
        })
        st.dataframe(token_data, height=150, use_container_width=True)

        st.markdown(f"""
        **Tokenizer Info:**
        - Vocabulary size: `{tokenizer.vocab_size}`
        - Special tokens: `{tokenizer.all_special_tokens}`
        - Padding token: `{tokenizer.pad_token}`
        - Max length: `{tokenizer.model_max_length}`
        """)

    with tab5:
        st.subheader("🖥️ Hardware Recommendation")
        params = model_info["params"]
        recommendation = get_hardware_recommendation(params)

        st.markdown(f"**Recommended hardware for {selected_model}:**")
        st.info(recommendation)

        # Bullets must not start with ">" — markdown would render that as a
        # blockquote nested in the list item.
        st.markdown("""
        **Recommendation Criteria:**
        - Under 100M parameters: Suitable for CPU or entry-level GPUs
        - 100-200M parameters: Requires mid-range GPUs
        - 200M+ parameters: Needs high-end GPUs/TPUs
        """)

    with tab6:
        st.subheader("💾 Memory Usage Estimation")
        params = model_info["params"]
        memory_mb = params * 4  # 1M params × 4 bytes (FP32) = 4 MB
        memory_gb = memory_mb / 1000  # decimal GB, consistent with decimal MB above

        st.metric("Estimated Memory (FP32)",
                  f"{memory_mb:.1f} MB / {memory_gb:.2f} GB")

        st.markdown("""
        **Memory Notes:**
        - Based on 4 bytes per parameter (FP32 precision)
        - Actual usage varies with:
          - Batch size
          - Sequence length
          - Precision (FP16/FP32)
          - Optimizer states (training)
        """)


# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()