File size: 4,802 Bytes
67fcf26 8d47fe0 67fcf26 8d47fe0 67fcf26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import streamlit as st
import streamlit_analytics
import torch
import torchvision.transforms as transforms
from transformers import ViTModel, ViTConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io
# Begin recording usage analytics for this script run
# (tracking is stopped at the very end of the script).
streamlit_analytics.start_tracking()

# Browser-tab title and full-width page layout.
st.set_page_config(layout="wide", page_title="ViewViz")

# Theme overrides injected as raw HTML: dark slate app background with white
# text, teal buttons, and a teal slider track.
_THEME_CSS = """
<style>
.stApp {
background-color: #2b3d4f;
color: #ffffff;
}
.stButton>button {
color: #2b3d4f;
background-color: #4fd1c5;
border-radius: 5px;
}
.stSlider>div>div>div>div {
background-color: #4fd1c5;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# Compute device. Flip USE_GPU to True to run on CUDA; the CPU is still used
# whenever CUDA is not available.
USE_GPU = False
device = torch.device('cuda') if (USE_GPU and torch.cuda.is_available()) else torch.device('cpu')
# Heatmap colour maps offered in the UI, keyed by display label.
# Insertion order is the order the labels appear in the selectbox.
COLOR_SCHEMES = {
    label: getattr(plt.cm, cmap_attr)
    for label, cmap_attr in (
        ('Plasma', 'plasma'),
        ('Viridis', 'viridis'),
        ('Magma', 'magma'),
        ('Inferno', 'inferno'),
        ('Cividis', 'cividis'),
        ('Spectral', 'Spectral'),
        ('Coolwarm', 'coolwarm'),
    )
}
# Pre-trained Vision Transformer, loaded once and shared across reruns.
@st.cache_resource
def load_model():
    """Return the ViT-B/16 (384px) backbone configured to return attention weights.

    The model is switched to inference mode and moved to the module-level
    ``device`` before being returned. ``st.cache_resource`` keeps a single
    instance alive across script reruns.
    """
    checkpoint = 'google/vit-base-patch16-384'
    # "eager" attention is requested so per-layer attention probabilities are
    # available (presumably fused attention kernels do not expose them).
    vit_config = ViTConfig.from_pretrained(
        checkpoint,
        output_attentions=True,
        attn_implementation="eager",
    )
    vit = ViTModel.from_pretrained(checkpoint, config=vit_config)
    vit.eval()
    return vit.to(device)


model = load_model()
# Image preprocessing for google/vit-base-patch16-384.
# That checkpoint was trained on 384x384 inputs normalised with mean = std = 0.5
# per channel (its published preprocessor config / model card). The torchvision
# ImageNet statistics (0.485/0.229, ...) shift the input distribution the model
# sees and therefore perturb the attention maps.
VIT_IMAGE_SIZE = 384
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD = [0.5, 0.5, 0.5]

preprocess = transforms.Compose([
    transforms.Resize((VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)),
    transforms.ToTensor(),  # HWC uint8 PIL image -> CHW float tensor in [0, 1]
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])
def get_attention_map(img):
    """Compute an attention-rollout saliency map for ``img``.

    Steps:
    - Run the module-level ``model`` on ``preprocess(img)``.
    - Average each layer's attention weights over heads.
    - Add the identity to account for residual connections, then re-normalise rows.
    - Multiply the per-layer matrices together (attention rollout).
    - Take the row for token 0 (CLS) over the patch tokens and reshape it onto
      the square patch grid.

    Args:
        img: PIL image accepted by ``preprocess`` (the caller passes RGB).

    Returns:
        2-D float numpy array of shape (grid, grid), one value per image patch.
        Values are not rescaled to [0, 1].

    Raises:
        RuntimeError: if the model did not return attention weights.
        ValueError: if the number of patch tokens is not a perfect square.
    """
    input_tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor, output_attentions=True)

    if outputs.attentions is None:
        raise RuntimeError(
            "Model returned no attention weights; it must be loaded with "
            "output_attentions=True and an attention implementation that exposes them."
        )

    # Stack per-layer attentions and drop the batch axis of size 1.
    att_mat = torch.stack(outputs.attentions).squeeze(1)
    # Average over heads -> (layers, tokens, tokens).
    att_mat = att_mat.mean(dim=1)

    # Residual connections: add the identity, then re-normalise each row.
    num_tokens = att_mat.size(-1)
    identity = torch.eye(num_tokens, dtype=att_mat.dtype, device=att_mat.device)
    aug_att_mat = att_mat + identity
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Attention rollout. Only the final product is needed, so keep a running
    # product instead of storing the joint attention for every layer.
    rollout = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        rollout = layer_att @ rollout

    # Token 0 is CLS; the remaining tokens are image patches on a square grid.
    num_patches = num_tokens - 1
    grid_size = int(round(float(np.sqrt(num_patches))))
    if grid_size * grid_size != num_patches:
        raise ValueError(
            f"Expected a square patch grid, got {num_patches} patch tokens."
        )

    mask = rollout[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
    return mask
def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
    """Blend a colour-mapped attention map over ``image``.

    Args:
        image: PIL image to draw on; it is converted to RGBA for blending.
        attention_map: 2-D float numpy array (e.g. from ``get_attention_map``).
        overlay_strength: blend weight; 0 shows only the image and 1 only the
            heatmap. The UI passes a value in [0, 1].
        color_scheme: matplotlib colormap, a callable mapping floats in [0, 1]
            to RGBA floats in [0, 1].

    Returns:
        RGBA PIL image with the same size as ``image``.
    """
    # Upsample the patch-level map to the image resolution. Forcing float32
    # makes PIL build a single-channel float ("F") image.
    resized = Image.fromarray(np.asarray(attention_map, dtype=np.float32))
    resized = resized.resize(image.size, Image.BICUBIC)
    heat = np.array(resized)

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing by
    # it would yield NaN, so render it as all zeros instead.
    lo, hi = heat.min(), heat.max()
    span = hi - lo
    if span > 0:
        heat = (heat - lo) / span
    else:
        heat = np.zeros_like(heat)

    # Colour the heatmap (H, W, 4) and blend it with the RGBA image in [0, 1].
    heat_rgba = color_scheme(heat)
    base = np.array(image.convert("RGBA")) / 255.0
    blended = base * (1 - overlay_strength) + heat_rgba * overlay_strength

    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    blended = np.clip(blended, 0.0, 1.0)
    return Image.fromarray((blended * 255).astype(np.uint8))
st.title("ViewViz")


@st.cache_data(show_spinner=False)
def _cached_attention_map(image_bytes):
    """Return ``get_attention_map`` for the raw uploaded bytes, memoised.

    Streamlit re-executes this script on every widget interaction, such as
    moving the overlay slider. Caching on the upload's bytes stops the ViT
    forward pass from re-running when only display settings change.
    """
    return get_attention_map(Image.open(io.BytesIO(image_bytes)).convert('RGB'))


uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image_bytes = uploaded_file.getvalue()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for corrupt
        # or mislabelled uploads; report it instead of showing a traceback.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.success("Starting Prediction Process...")
        attention_map = _cached_attention_map(image_bytes)

        col1, col2 = st.columns(2)
        with col1:
            overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0
        with col2:
            color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))
        color_scheme = COLOR_SCHEMES[color_scheme_name]

        overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)
        st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)

        # Offer the blended image as a PNG download.
        buf = io.BytesIO()
        overlayed_image.save(buf, format="PNG")
        st.download_button(
            label="Download Image with Attention Map",
            data=buf.getvalue(),
            file_name="attention_map_overlay.png",
            mime="image/png",
        )

# Reached on every run (including the unreadable-upload path), so analytics
# tracking is always closed.
streamlit_analytics.stop_tracking()