viewviz / app.py
prdev's picture
Update app.py
8d47fe0 verified
raw
history blame contribute delete
No virus
4.8 kB
import streamlit as st
import streamlit_analytics
import torch
import torchvision.transforms as transforms
from transformers import ViTModel, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
# Begin collecting usage analytics for this script run; the matching
# stop_tracking() call is at the very end of the script.
streamlit_analytics.start_tracking()

# Browser-tab title and wide page layout.
st.set_page_config(page_title="ViewViz", layout="wide")

# Custom color scheme for Streamlit: dark app background, white text and the
# accent colour #4fd1c5 on buttons and the slider element.
_THEME_CSS = """
<style>
.stApp {
background-color: #2b3d4f;
color: #ffffff;
}
.stButton>button {
color: #2b3d4f;
background-color: #4fd1c5;
border-radius: 5px;
}
.stSlider>div>div>div>div {
background-color: #4fd1c5;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# Device selection: inference runs on CUDA only when USE_GPU is switched on
# *and* torch reports an available CUDA device; otherwise it uses the CPU.
USE_GPU = False  # Set to True to use GPU, False to use CPU
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Heatmap colour maps offered in the UI, keyed by their display label.
# Insertion order is the order the options appear in the selectbox.
COLOR_SCHEMES = dict(
    Plasma=plt.cm.plasma,
    Viridis=plt.cm.viridis,
    Magma=plt.cm.magma,
    Inferno=plt.cm.inferno,
    Cividis=plt.cm.cividis,
    Spectral=plt.cm.Spectral,
    Coolwarm=plt.cm.coolwarm,
)
# Load the pre-trained Vision Transformer model
@st.cache_resource
def load_model():
    """Build the ViT-Base/16 (384 px) model configured to emit attention weights.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded model
    across script reruns instead of reloading it each time.

    Returns:
        The ``ViTModel`` switched to eval mode and moved to ``device``.
    """
    checkpoint = 'google/vit-base-patch16-384'
    # Eager attention is requested, presumably because fused attention
    # implementations do not return per-head weights — confirm if changed.
    vit_config = ViTConfig.from_pretrained(
        checkpoint, output_attentions=True, attn_implementation="eager"
    )
    vit = ViTModel.from_pretrained(checkpoint, config=vit_config)
    vit.eval()
    return vit.to(device)


model = load_model()
# Per-channel normalisation statistics of the google/vit-* checkpoints' image
# processor (mean = std = 0.5), which map ToTensor's [0, 1] pixels to [-1, 1].
# ImageNet statistics (0.485/0.229, ...) do not match what this model was
# trained on and shift the inputs off-distribution.
_VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
_VIT_IMAGE_STD = [0.5, 0.5, 0.5]

# Image preprocessing: resize to the model's 384x384 input resolution,
# convert to a CHW float tensor in [0, 1], then normalise per channel.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_VIT_IMAGE_MEAN, std=_VIT_IMAGE_STD),
])
def _attention_rollout(attentions):
    """Combine per-layer attention weights into a single rollout matrix.

    Heads are averaged per layer, the identity is added to account for the
    residual connections, rows are re-normalised to sum to 1, and the layer
    matrices are multiplied from the first layer to the last.

    Args:
        attentions: sequence of per-layer attention tensors from the model
            (``outputs.attentions``); presumably each is
            (1, heads, tokens, tokens) for a single image — TODO confirm.

    Returns:
        A (tokens, tokens) tensor with the rolled-out attention.
    """
    # Stack layers, drop the size-1 batch axis, then average over heads.
    att_mat = torch.stack(attentions).squeeze(1).mean(dim=1)
    # Residual connections: add the identity, then re-normalise each row.
    identity = torch.eye(att_mat.size(-1), dtype=att_mat.dtype, device=att_mat.device)
    aug_att_mat = att_mat + identity
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)
    # Recursively multiply: joint_n = A_n @ joint_{n-1}, starting at layer 0.
    joint = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        joint = torch.matmul(layer_att, joint)
    return joint


def get_attention_map(img):
    """Return the attention-rollout map from the first token to every patch.

    Args:
        img: RGB PIL image; resized and normalised by the module-level
            ``preprocess`` pipeline before being fed to ``model``.

    Returns:
        A (grid_size, grid_size) NumPy array of rollout weights, one per
        image patch.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square.
    """
    input_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)
        rollout = _attention_rollout(outputs.attentions)
    # Token 0 is the class token; the remaining tokens are the image patches.
    # Derive the grid from the patch count itself rather than relying on
    # int() truncating sqrt(patches + 1).
    num_patches = rollout.size(-1) - 1
    grid_size = int(round(np.sqrt(num_patches)))
    if grid_size * grid_size != num_patches:
        raise ValueError(
            f"Cannot reshape {num_patches} patch tokens into a square grid"
        )
    return rollout[0, 1:].reshape(grid_size, grid_size).cpu().numpy()
def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colour-mapped attention heatmap onto an image.

    Args:
        image: PIL image to draw on; converted to RGBA internally.
        attention_map: 2-D float array of attention scores at any resolution;
            it is resized to ``image.size`` with bicubic interpolation.
        overlay_strength: blend weight; 0 keeps only the image and 1 keeps
            only the heatmap. The UI passes values in [0, 1].
        color_scheme: callable mapping an array of values in [0, 1] to RGBA
            floats of shape (..., 4), e.g. a matplotlib colormap.

    Returns:
        A new RGBA PIL image.
    """
    # Resize attention map to match image size (PIL sizes are (width, height)).
    resized = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    heat = np.array(resized)

    # Min-max normalise to [0, 1]. A constant map has zero range and would
    # divide by zero (NaNs that corrupt the uint8 cast), so map it to zeros.
    low = heat.min()
    span = heat.max() - low
    if span > 0:
        heat = (heat - low) / span
    else:
        heat = np.zeros_like(heat)

    heat_rgba = color_scheme(heat)
    image_array = np.array(image.convert("RGBA")) / 255.0

    # Overlay attention map on image with adjustable strength; clip before the
    # cast so out-of-range strengths cannot wrap around in uint8.
    blended = image_array * (1 - overlay_strength) + heat_rgba * overlay_strength
    return Image.fromarray((np.clip(blended, 0.0, 1.0) * 255).astype(np.uint8))
@st.cache_data(show_spinner=False, max_entries=64)
def _cached_attention_map(image_bytes):
    """Compute the attention map for an encoded image, memoised by Streamlit.

    Streamlit reruns this script on every widget interaction. Without a cache,
    each slider move or colour-scheme change would repeat the ViT forward pass.
    Keying on the raw upload bytes means only a new image triggers a
    recomputation.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return get_attention_map(img)


st.title("ViewViz")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image_bytes = uploaded_file.getvalue()
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    st.success("Starting Prediction Process...")
    attention_map = _cached_attention_map(image_bytes)

    col1, col2 = st.columns(2)
    with col1:
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0
    with col2:
        color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))
    color_scheme = COLOR_SCHEMES[color_scheme_name]

    overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept for compatibility with older versions.
    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Option to download the overlayed image as PNG.
    buf = io.BytesIO()
    overlayed_image.save(buf, format="PNG")
    st.download_button(
        label="Download Image with Attention Map",
        data=buf.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png",
    )

streamlit_analytics.stop_tracking()