prdev commited on
Commit
67fcf26
·
verified ·
1 Parent(s): decfb2f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit_analytics
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from transformers import ViTModel, ViTConfig
7
+ from PIL import Image
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+
12
+ streamlit_analytics.start_tracking()
13
+
14
+ # Set page config for custom theme
15
+ st.set_page_config(page_title="Where will they look?", layout="wide")
16
+
17
+ # Custom color scheme for Streamlit
18
+ st.markdown("""
19
+ <style>
20
+ .stApp {
21
+ background-color: #2b3d4f;
22
+ color: #ffffff;
23
+ }
24
+ .stButton>button {
25
+ color: #2b3d4f;
26
+ background-color: #4fd1c5;
27
+ border-radius: 5px;
28
+ }
29
+ .stSlider>div>div>div>div {
30
+ background-color: #4fd1c5;
31
+ }
32
+ </style>
33
+ """, unsafe_allow_html=True)
34
+
35
+ # Set device preference
36
+ USE_GPU = False # Set to True to use GPU, False to use CPU
37
+ device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
38
+
39
+ # Available color schemes
40
+ COLOR_SCHEMES = {
41
+ 'Plasma': plt.cm.plasma,
42
+ 'Viridis': plt.cm.viridis,
43
+ 'Magma': plt.cm.magma,
44
+ 'Inferno': plt.cm.inferno,
45
+ 'Cividis': plt.cm.cividis,
46
+ 'Spectral': plt.cm.Spectral,
47
+ 'Coolwarm': plt.cm.coolwarm
48
+ }
49
+
50
+ # Load the pre-trained Vision Transformer model
51
+ @st.cache_resource
52
+ def load_model():
53
+ model_name = 'google/vit-base-patch16-384'
54
+ config = ViTConfig.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")
55
+ model = ViTModel.from_pretrained(model_name, config=config)
56
+ model.eval()
57
+ return model.to(device)
58
+
59
+ model = load_model()
60
+
61
+ # Image preprocessing
62
+ preprocess = transforms.Compose([
63
+ transforms.Resize((384, 384)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
66
+ ])
67
+
68
+ def get_attention_map(img):
69
+ # Preprocess the image
70
+ input_tensor = preprocess(img).unsqueeze(0).to(device)
71
+
72
+ # Get model output
73
+ with torch.no_grad():
74
+ outputs = model(input_tensor, output_attentions=True)
75
+
76
+ # Process attention maps
77
+ att_mat = torch.stack(outputs.attentions).squeeze(1)
78
+ att_mat = torch.mean(att_mat, dim=1)
79
+
80
+ # Add residual connections
81
+ residual_att = torch.eye(att_mat.size(-1)).unsqueeze(0).to(device)
82
+ aug_att_mat = att_mat + residual_att
83
+ aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
84
+
85
+ # Recursively multiply the weight matrices
86
+ joint_attentions = torch.zeros(aug_att_mat.size()).to(device)
87
+ joint_attentions[0] = aug_att_mat[0]
88
+ for n in range(1, aug_att_mat.size(0)):
89
+ joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])
90
+
91
+ # Get final attention map
92
+ v = joint_attentions[-1]
93
+ grid_size = int(np.sqrt(aug_att_mat.size(-1)))
94
+ mask = v[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
95
+
96
+ return mask
97
+
98
+ def overlay_attention_map(image, attention_map, overlay_strength, color_scheme):
99
+ # Resize attention map to match image size
100
+ attention_map = Image.fromarray(attention_map).resize(image.size, Image.BICUBIC)
101
+ attention_map = np.array(attention_map)
102
+
103
+ # Normalize attention map
104
+ attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
105
+
106
+ # Apply selected color map
107
+ attention_map_color = color_scheme(attention_map)
108
+
109
+ # Convert image to RGBA
110
+ image_rgba = image.convert("RGBA")
111
+ image_array = np.array(image_rgba) / 255.0
112
+
113
+ # Overlay attention map on image with adjustable strength
114
+ overlayed_image = image_array * (1 - overlay_strength) + attention_map_color * overlay_strength
115
+
116
+ return Image.fromarray((overlayed_image * 255).astype(np.uint8))
117
+
118
+ st.title("Where will they look?")
119
+
120
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
121
+
122
+ if uploaded_file is not None:
123
+ image = Image.open(uploaded_file).convert('RGB')
124
+
125
+ st.success("Starting Prediction Process...")
126
+ attention_map = get_attention_map(image)
127
+
128
+ col1, col2 = st.columns(2)
129
+
130
+ with col1:
131
+ overlay_strength = st.slider("Heatmap Overlay Percentage", 0, 100, 50) / 100.0
132
+
133
+ with col2:
134
+ color_scheme_name = st.selectbox("Choose Heatmap Color Scheme", list(COLOR_SCHEMES.keys()))
135
+
136
+ color_scheme = COLOR_SCHEMES[color_scheme_name]
137
+
138
+ overlayed_image = overlay_attention_map(image, attention_map, overlay_strength, color_scheme)
139
+
140
+ st.image(overlayed_image, caption='Image with Heatmap Overlay', use_column_width=True)
141
+
142
+ # Option to download the overlayed image
143
+ buf = io.BytesIO()
144
+ overlayed_image.save(buf, format="PNG")
145
+ btn = st.download_button(
146
+ label="Download Image with Attention Map",
147
+ data=buf.getvalue(),
148
+ file_name="attention_map_overlay.png",
149
+ mime="image/png"
150
+ )
151
+
152
+ streamlit_analytics.stop_tracking()