"""ViewViz: Streamlit app that overlays a ViT attention heatmap on an image.

The user uploads an image, the app runs it through a pre-trained Vision
Transformer, combines the per-layer attention matrices into a single
patch-level map, and blends that map over the original image using a
selectable colormap and strength.
"""

import io

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import streamlit_analytics
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import ViTConfig, ViTModel

streamlit_analytics.start_tracking()

# Set page config for custom theme
st.set_page_config(page_title="ViewViz", layout="wide")

# Custom color scheme for Streamlit
# NOTE(review): the injected markup is blank in this file, so no custom
# styling is actually applied -- confirm whether a <style> block was lost.
st.markdown(""" """, unsafe_allow_html=True)

# Set device preference
USE_GPU = False  # Set to True to use GPU, False to use CPU
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Available color schemes (display name -> matplotlib colormap)
COLOR_SCHEMES = {
    'Plasma': plt.cm.plasma,
    'Viridis': plt.cm.viridis,
    'Magma': plt.cm.magma,
    'Inferno': plt.cm.inferno,
    'Cividis': plt.cm.cividis,
    'Spectral': plt.cm.Spectral,
    'Coolwarm': plt.cm.coolwarm,
}


# Load the pre-trained Vision Transformer model
@st.cache_resource
def load_model():
    """Load the ViT model once per server process and move it to ``device``.

    The config requests attention outputs and the "eager" attention
    implementation so that per-layer attention weights are returned.

    Returns:
        The ``ViTModel`` in eval mode, placed on ``device``.
    """
    model_name = 'google/vit-base-patch16-384'
    config = ViTConfig.from_pretrained(
        model_name,
        output_attentions=True,
        attn_implementation="eager",
    )
    model = ViTModel.from_pretrained(model_name, config=config)
    model.eval()
    return model.to(device)


model = load_model()

# Image preprocessing: fixed 384x384 input, tensor conversion, then
# per-channel mean/std normalization.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_attention_map(img):
    """Compute a patch-grid attention map for ``img``.

    Per-layer attentions are averaged over heads, augmented with the
    identity (to account for residual connections), row-normalized, and
    multiplied together across layers. The row of the first token is then
    reshaped into a square grid over the remaining tokens.

    Args:
        img: A PIL image accepted by ``preprocess``.

    Returns:
        A 2-D ``numpy.ndarray`` of shape ``(grid_size, grid_size)``.
    """
    # Preprocess the image and add a batch dimension of size 1
    input_tensor = preprocess(img).unsqueeze(0).to(device)

    # Get model output
    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)

    # Stack layers and drop the size-1 batch dim, then average over heads:
    # (layers, heads, tokens, tokens) -> (layers, tokens, tokens)
    att_mat = torch.stack(outputs.attentions).squeeze(1)
    att_mat = att_mat.mean(dim=1)

    # Add residual connections and re-normalize each row to sum to 1
    residual_att = torch.eye(att_mat.size(-1), device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Multiply the weight matrices layer by layer. Only the final product is
    # needed, so keep a running matrix instead of storing every layer.
    rollout = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        rollout = torch.matmul(layer_att, rollout)

    # Row 0 is the first token's attention; column 0 (that same token) is
    # dropped so the remaining entries form a square patch grid.
    num_patches = rollout.size(-1) - 1
    grid_size = int(np.sqrt(num_patches))
    mask = rollout[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
    return mask


@st.cache_data(show_spinner=False, max_entries=8)
def _cached_attention_map(image_bytes, _image):
    """Return ``get_attention_map(_image)``, cached per uploaded file content.

    Streamlit re-executes the script on every widget interaction; without
    this cache each slider or selectbox change would repeat the model
    forward pass for the same image.

    Args:
        image_bytes: Raw bytes of the upload; this is the cache key.
        _image: The decoded PIL image. The leading underscore tells
            Streamlit not to hash this argument.

    Returns:
        The attention map array produced by ``get_attention_map``.
    """
    return get_attention_map(_image)


def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colorized attention map over ``image``.

    Args:
        image: Base PIL image.
        attention_map: 2-D float array; it is resized to ``image.size``.
        overlay_strength: Blend weight of the heatmap; the image gets
            ``1 - overlay_strength``.
        color_scheme: Callable mapping an array of values in [0, 1] to RGBA
            (e.g. a matplotlib colormap).

    Returns:
        A new PIL image (RGBA) with the heatmap blended in.
    """
    # Resize attention map to match image size
    resized = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    heat = np.array(resized)

    # Normalize to [0, 1]. A constant map has zero range; map it to all
    # zeros instead of dividing by zero and propagating NaNs into the image.
    low = heat.min()
    span = heat.max() - low
    if span > 0:
        heat = (heat - low) / span
    else:
        heat = np.zeros_like(heat)

    # Apply selected color map (yields RGBA floats)
    heat_rgba = color_scheme(heat)

    # Convert image to RGBA floats in [0, 1] so both layers have 4 channels
    image_array = np.array(image.convert("RGBA")) / 255.0

    # Overlay attention map on image with adjustable strength
    blended = image_array * (1 - overlay_strength) + heat_rgba * overlay_strength
    return Image.fromarray((blended * 255).astype(np.uint8))


st.title("ViewViz")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    st.success("Starting Prediction Process...")

    # Keyed on the raw upload bytes so widget-triggered reruns reuse the map
    attention_map = _cached_attention_map(uploaded_file.getvalue(), image)

    col1, col2 = st.columns(2)
    with col1:
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0
    with col2:
        color_scheme_name = st.selectbox(
            "Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys())
        )

    color_scheme = COLOR_SCHEMES[color_scheme_name]
    overlayed_image = overlay_attention_map(
        image, attention_map, overlay_strength, color_scheme
    )

    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Option to download the overlayed image
    buf = io.BytesIO()
    overlayed_image.save(buf, format="PNG")
    btn = st.download_button(
        label="Download Image with Attention Map",
        data=buf.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png",
    )

streamlit_analytics.stop_tracking()