|
import streamlit as st |
|
import streamlit_analytics |
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
from transformers import ViTModel, ViTConfig |
|
from PIL import Image |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import io |
|
|
|
# Open the streamlit_analytics tracking window for this script run; the
# matching stop_tracking() call is at the end of the script.
streamlit_analytics.start_tracking()

st.set_page_config(page_title="ViewViz", layout="wide")

# Custom theme: dark slate page background with teal buttons and slider track.
_CUSTOM_CSS = """
<style>
.stApp {
    background-color: #2b3d4f;
    color: #ffffff;
}
.stButton>button {
    color: #2b3d4f;
    background-color: #4fd1c5;
    border-radius: 5px;
}
.stSlider>div>div>div>div {
    background-color: #4fd1c5;
}
</style>
"""

# unsafe_allow_html is required so the <style> element is emitted as markup.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
|
# Master switch for GPU inference; CUDA is only selected when this is True
# AND torch reports a usable CUDA runtime.
USE_GPU = False

_cuda_selected = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if _cuda_selected else 'cpu')
|
|
|
|
|
# Display name -> matplotlib colormap. The keys are offered to the user in the
# colour-scheme selectbox; the values are called on a normalised heatmap.
COLOR_SCHEMES = dict(
    Plasma=plt.cm.plasma,
    Viridis=plt.cm.viridis,
    Magma=plt.cm.magma,
    Inferno=plt.cm.inferno,
    Cividis=plt.cm.cividis,
    Spectral=plt.cm.Spectral,
    Coolwarm=plt.cm.coolwarm,
)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the pretrained ViT backbone, switch it to eval mode and move it to `device`.

    The config requests attention outputs and the "eager" attention
    implementation. The function is wrapped in ``st.cache_resource`` so the
    loaded model can be reused instead of being rebuilt on each call.

    Returns:
        The ``ViTModel`` instance placed on the module-level ``device``.
    """
    checkpoint = 'google/vit-base-patch16-384'
    vit = ViTModel.from_pretrained(
        checkpoint,
        config=ViTConfig.from_pretrained(
            checkpoint,
            output_attentions=True,
            attn_implementation="eager",
        ),
    )
    vit.eval()
    return vit.to(device)


# Module-level model instance used by get_attention_map().
model = load_model()
|
|
|
|
|
# Input pipeline for the google/vit-base-patch16-384 checkpoint: resize to the
# 384x384 resolution, convert to a [0, 1] tensor, then normalise each RGB
# channel with mean 0.5 / std 0.5 (i.e. map to [-1, 1]).
# The Google ViT checkpoints were trained with this 0.5/0.5 normalisation, not
# the ImageNet mean/std used by torchvision CNNs; feeding ImageNet-normalised
# inputs shifts the input distribution the model sees.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
|
|
|
def get_attention_map(img):
    """Compute an attention-rollout mask for `img` with the module-level model.

    The image is run through ``preprocess`` and ``model``; the per-layer
    attention matrices are averaged over dim 1 (presumably the heads axis),
    augmented with an identity matrix, row-normalised, and multiplied layer
    by layer.

    Args:
        img: Image accepted by ``preprocess`` (the caller passes an RGB PIL image).

    Returns:
        A square 2-D NumPy array taken from row 0, columns 1 onward, of the
        final rolled-out matrix (presumably the class token's attention to the
        patch tokens), reshaped to (side, side).
    """
    batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        result = model(batch, output_attentions=True)

    # Stack the per-layer attentions, drop the size-1 batch axis, then average
    # over dim 1.
    layer_attn = torch.stack(result.attentions).squeeze(1).mean(dim=1)

    # Add the identity to each layer's matrix and renormalise so every row
    # sums to 1 again.
    num_tokens = layer_attn.size(-1)
    identity = torch.eye(num_tokens).unsqueeze(0).to(device)
    augmented = layer_attn + identity
    augmented = augmented / augmented.sum(dim=-1, keepdim=True)

    # Rollout: left-multiply each layer's matrix onto the running product.
    rollout = augmented[0]
    for layer_matrix in augmented[1:]:
        rollout = layer_matrix @ rollout

    # int() truncates sqrt(num_tokens), so the extra leading token does not
    # change the side length of the square patch grid.
    side = int(np.sqrt(num_tokens))
    return rollout[0, 1:].reshape(side, side).detach().cpu().numpy()
|
|
|
def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colour-mapped attention heatmap over an image.

    Args:
        image: PIL image to draw on; its ``size`` sets the output resolution.
        attention_map: 2-D NumPy array of attention scores (any value range).
        overlay_strength: Blend weight of the heatmap; 0 shows only the image,
            1 shows only the heatmap.
        color_scheme: Callable mapping an array of values in [0, 1] to RGBA
            floats in [0, 1] (e.g. a matplotlib colormap).

    Returns:
        A new PIL image built from the blended RGBA array.
    """
    # Upsample the coarse attention grid to the image's resolution.
    heatmap = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
    heatmap = np.array(heatmap)

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing by
    # it would yield NaNs (and an undefined uint8 cast below), so map it to
    # all zeros instead.
    low = heatmap.min()
    value_range = heatmap.max() - low
    if value_range > 0:
        heatmap = (heatmap - low) / value_range
    else:
        heatmap = np.zeros_like(heatmap)

    heatmap_rgba = color_scheme(heatmap)

    # Work in float RGBA so the image and heatmap share shape and scale.
    image_rgba = np.array(image.convert("RGBA")) / 255.0

    blended = image_rgba * (1 - overlay_strength) + heatmap_rgba * overlay_strength

    # Clip before the uint8 cast so an out-of-range overlay_strength cannot
    # wrap around; in-range inputs are unaffected.
    blended = np.clip(blended, 0.0, 1.0)

    return Image.fromarray((blended * 255).astype(np.uint8))
|
|
|
st.title("ViewViz")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')

    st.success("Starting Prediction Process...")

    # Streamlit re-executes this script on every widget interaction, so
    # without a cache each slider or selectbox change would repeat the model
    # forward pass. Keep the last attention map in session state, keyed by the
    # uploaded bytes, and recompute only when a different file is uploaded.
    file_bytes = uploaded_file.getvalue()
    cached = st.session_state.get("attention_cache")
    if cached is None or cached[0] != file_bytes:
        cached = (file_bytes, get_attention_map(image))
        st.session_state["attention_cache"] = cached
    attention_map = cached[1]

    # Controls sit side by side: blend strength on the left, colormap on the right.
    col1, col2 = st.columns(2)

    with col1:
        # Slider is in percent; convert to a 0..1 blend weight.
        overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0

    with col2:
        color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))

    color_scheme = COLOR_SCHEMES[color_scheme_name]

    overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)

    # NOTE(review): newer Streamlit releases appear to deprecate
    # use_column_width - confirm against the pinned Streamlit version.
    st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

    # Encode the overlay as PNG in memory so it can be offered for download.
    buf = io.BytesIO()
    overlayed_image.save(buf, format="PNG")
    btn = st.download_button(
        label="Download Image with Attention Map",
        data=buf.getvalue(),
        file_name="attention_map_overlay.png",
        mime="image/png"
    )

# Close the tracking window opened by start_tracking() at the top of the script.
streamlit_analytics.stop_tracking()