Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
@@ -113,7 +113,6 @@ def visualize_attention(
|
|
113 |
model = model.to('cuda')
|
114 |
attention_maps = process_image(image, model, extractor)
|
115 |
|
116 |
-
# FIXME handle wider range of models that may not have num_prefix_tokens attr
|
117 |
num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1) # Default to 1 class token if not specified
|
118 |
|
119 |
# Convert PIL Image to numpy array
|
@@ -142,7 +141,6 @@ def visualize_attention(
|
|
142 |
raise ValueError(f"Invalid head fusion method: {head_fusion}")
|
143 |
|
144 |
# Use the first token's attention (usually the class token)
|
145 |
-
# FIXME handle different prefix token scenarios
|
146 |
attn_map = attn_map[0]
|
147 |
|
148 |
# Reshape the attention map to 2D
|
@@ -152,7 +150,7 @@ def visualize_attention(
|
|
152 |
# Interpolate to match image size
|
153 |
attn_map = attn_map.unsqueeze(0).unsqueeze(0)
|
154 |
attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
|
155 |
-
attn_map = attn_map.squeeze().cpu().numpy()
|
156 |
|
157 |
# Normalize attention map
|
158 |
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
|
@@ -228,4 +226,5 @@ iface = gr.Interface(
|
|
228 |
description="Upload an image and select a timm model to visualize its attention maps."
|
229 |
)
|
230 |
|
231 |
-
|
|
|
|
113 |
model = model.to('cuda')
|
114 |
attention_maps = process_image(image, model, extractor)
|
115 |
|
|
|
116 |
num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1) # Default to 1 class token if not specified
|
117 |
|
118 |
# Convert PIL Image to numpy array
|
|
|
141 |
raise ValueError(f"Invalid head fusion method: {head_fusion}")
|
142 |
|
143 |
# Use the first token's attention (usually the class token)
|
|
|
144 |
attn_map = attn_map[0]
|
145 |
|
146 |
# Reshape the attention map to 2D
|
|
|
150 |
# Interpolate to match image size
|
151 |
attn_map = attn_map.unsqueeze(0).unsqueeze(0)
|
152 |
attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
|
153 |
+
attn_map = attn_map.squeeze().detach().cpu().numpy() # Detach before converting to numpy
|
154 |
|
155 |
# Normalize attention map
|
156 |
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
|
|
|
226 |
description="Upload an image and select a timm model to visualize its attention maps."
|
227 |
)
|
228 |
|
229 |
+
# Launch the interface with share=True to create a public link
|
230 |
+
iface.launch(share=True)
|