taesiri committed
Commit d97d4d2 · 1 Parent(s): dfc8148
Files changed (1)
  1. app.py +3 -4
app.py CHANGED
@@ -113,7 +113,6 @@ def visualize_attention(
     model = model.to('cuda')
     attention_maps = process_image(image, model, extractor)
 
-    # FIXME handle wider range of models that may not have num_prefix_tokens attr
     num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1) # Default to 1 class token if not specified
 
     # Convert PIL Image to numpy array
@@ -142,7 +141,6 @@ def visualize_attention(
     raise ValueError(f"Invalid head fusion method: {head_fusion}")
 
     # Use the first token's attention (usually the class token)
-    # FIXME handle different prefix token scenarios
     attn_map = attn_map[0]
 
     # Reshape the attention map to 2D
@@ -152,7 +150,7 @@ def visualize_attention(
     # Interpolate to match image size
     attn_map = attn_map.unsqueeze(0).unsqueeze(0)
     attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
-    attn_map = attn_map.squeeze().cpu().numpy()
+    attn_map = attn_map.squeeze().detach().cpu().numpy() # Detach before converting to numpy
 
     # Normalize attention map
     attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
@@ -228,4 +226,5 @@ iface = gr.Interface(
     description="Upload an image and select a timm model to visualize its attention maps."
 )
 
-iface.launch()
+# Launch the interface with share=True to create a public link
+iface.launch(share=True)