Update
app.py (CHANGED)
@@ -24,8 +24,9 @@ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
     """Load a model from timm and prepare it for attention extraction."""
     timm.layers.set_fused_attn(False)
     model = create_model(model_name, pretrained=True)
+    model = model.cuda()  # Move the model to CUDA
     model.eval()
-    extractor = AttentionExtract(model, method='fx')
+    extractor = AttentionExtract(model, method='fx')
     return model, extractor
 
 @spaces.GPU
@@ -46,8 +47,8 @@ def process_image(
         is_training=False
     )
 
-    # Preprocess the image
-    tensor = transform(image).unsqueeze(0).
+    # Preprocess the image and move to CUDA
+    tensor = transform(image).unsqueeze(0).cuda()
 
     # Extract attention maps
     attention_maps = extractor(tensor)
@@ -67,12 +68,11 @@ def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, f
     return masked_image.astype(np.uint8)
 
 def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
-
-    result = torch.eye(attentions[0].size(-1)).to(
+    device = attentions[0].device
+    result = torch.eye(attentions[0].size(-1)).to(device)
     with torch.no_grad():
         for attention in attentions:
             if head_fusion.startswith('mean'):
-                # mean_std fusion doesn't appear to make sense with rollout
                 attention_heads_fused = attention.mean(dim=0)
             elif head_fusion == "max":
                 attention_heads_fused = attention.amax(dim=0)
@@ -87,14 +87,13 @@ def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
             indices = indices[indices >= num_prefix_tokens]
             flat[indices] = 0
 
-            I = torch.eye(attention_heads_fused.size(-1)).to(
+            I = torch.eye(attention_heads_fused.size(-1)).to(device)
             a = (attention_heads_fused + 1.0 * I) / 2
             a = a / a.sum(dim=-1)
             result = torch.matmul(a, result)
 
     # Look at the total attention between the prefix tokens (usually class tokens)
     # and the image patches
-    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
     mask = result[0, num_prefix_tokens:]
     width = int(mask.size(-1) ** 0.5)
     mask = mask.reshape(width, width).cpu().numpy()
@@ -110,7 +109,6 @@ def visualize_attention(
 ) -> Tuple[List[Image.Image], Image.Image]:
     """Visualize attention maps and rollout for the given image and model."""
     model, extractor = load_model(model_name)
-    model = model.to('cuda')
     attention_maps = process_image(image, model, extractor)
 
     num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified
@@ -150,7 +148,7 @@ def visualize_attention(
         # Interpolate to match image size
         attn_map = attn_map.unsqueeze(0).unsqueeze(0)
         attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
-        attn_map = attn_map.squeeze().
+        attn_map = attn_map.squeeze().cpu().numpy()  # Move to CPU before converting to numpy
 
         # Normalize attention map
         attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
@@ -179,6 +177,9 @@ def visualize_attention(
         visualizations.append(vis_image)
         plt.close(fig)
 
+    # Ensure tensors are on CPU before converting to numpy
+    attentions_for_rollout = [attn.cpu() for attn in attentions_for_rollout]
+
     # Calculate rollout
     rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
 
@@ -209,7 +210,6 @@ def visualize_attention(
 
     return visualizations, rollout_image
 
-
 # Create Gradio interface
 iface = gr.Interface(
     fn=visualize_attention,
@@ -231,5 +231,5 @@ iface = gr.Interface(
     description="Upload an image and select a timm model to visualize its attention maps."
 )
 
-# Launch the interface
-iface.launch(
+# Launch the interface
+iface.launch()
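
For reference, below is a minimal, self-contained sketch of the attention-rollout recurrence that the rollout() function in this diff implements: fuse the heads, mix in the identity matrix for the residual connection, row-normalize, and multiply the per-block matrices together. It is not the app's code; the inputs are synthetic (heads, tokens, tokens) maps, only mean head fusion is shown, rows are normalized with keepdim=True, and the discard_ratio pruning step is omitted. The final .cpu().numpy() call mirrors the device fix made in this commit.

import torch

def simple_rollout(attentions, num_prefix_tokens=1):
    device = attentions[0].device
    num_tokens = attentions[0].size(-1)
    result = torch.eye(num_tokens, device=device)
    with torch.no_grad():
        for attn in attentions:                                 # one (heads, N, N) map per block
            a = attn.mean(dim=0)                                # "mean" head fusion
            a = (a + torch.eye(num_tokens, device=device)) / 2  # fold in the residual connection
            a = a / a.sum(dim=-1, keepdim=True)                 # re-normalize each row
            result = a @ result                                 # accumulate attention across blocks
    # attention flowing from the first prefix (class) token to the patch tokens
    return result[0, num_prefix_tokens:]

# Fake data: 12 blocks, 6 heads, 197 tokens (1 class token + 14x14 patches)
fake_attn = [torch.softmax(torch.randn(6, 197, 197), dim=-1) for _ in range(12)]
mask = simple_rollout(fake_attn)
side = int(mask.numel() ** 0.5)
mask_np = mask.reshape(side, side).cpu().numpy()                # move to CPU before numpy, as in the fix
print(mask_np.shape)                                            # (14, 14)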