taesiri committed
Commit 980c76b · 1 Parent(s): ac69117
Files changed (1)
  1. app.py +13 -13
app.py CHANGED
@@ -24,8 +24,9 @@ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
     """Load a model from timm and prepare it for attention extraction."""
     timm.layers.set_fused_attn(False)
     model = create_model(model_name, pretrained=True)
+    model = model.cuda()  # Move the model to CUDA
     model.eval()
-    extractor = AttentionExtract(model, method='fx')  # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
+    extractor = AttentionExtract(model, method='fx')
     return model, extractor
 
 @spaces.GPU
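
This hunk moves the model onto CUDA inside load_model itself (a later hunk drops the matching model.to('cuda') from visualize_attention) and trims the long comment on the AttentionExtract call. Hard-coding .cuda() assumes a CUDA device is present, as it is on the GPU-backed Space this app targets. Purely as a hedged sketch, not part of this commit, a device-agnostic variant of the same setup could look like this (load_model_sketch is a hypothetical name; the AttentionExtract import path is an assumption based on recent timm versions):

import timm
import torch
from timm.utils import AttentionExtract  # assumption: same class app.py already imports

def load_model_sketch(model_name: str):
    # Sketch only, not part of this commit.
    timm.layers.set_fused_attn(False)                 # unfused attention keeps the attention maps reachable for extraction
    model = timm.create_model(model_name, pretrained=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device).eval()                   # fall back to CPU when no GPU is visible
    extractor = AttentionExtract(model, method='fx')  # same extractor call as the hunk above
    return model, extractor
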
@@ -46,8 +47,8 @@ def process_image(
         is_training=False
     )
 
-    # Preprocess the image
-    tensor = transform(image).unsqueeze(0).to('cuda')
+    # Preprocess the image and move to CUDA
+    tensor = transform(image).unsqueeze(0).cuda()
 
     # Extract attention maps
     attention_maps = extractor(tensor)
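
For context, the transform used in this hunk is the eval-time preprocessing pipeline whose call ends at the is_training=False context line above, presumably built from the model's own data config with timm's helpers. A rough, hedged sketch of the full path feeding the extractor under that assumption ('cat.jpg' is a placeholder input; model and extractor come from load_model):

from PIL import Image
from timm.data import create_transform, resolve_model_data_config

config = resolve_model_data_config(model)                     # input size, mean/std, interpolation for this model
transform = create_transform(**config, is_training=False)     # eval-time preprocessing
tensor = transform(Image.open('cat.jpg').convert('RGB')).unsqueeze(0).cuda()
attention_maps = extractor(tensor)                            # attention tensors keyed by layer/node name (assumption)
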
@@ -67,12 +68,11 @@ def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, f
     return masked_image.astype(np.uint8)
 
 def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
-    # based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
-    result = torch.eye(attentions[0].size(-1)).to(attentions[0].device)
+    device = attentions[0].device
+    result = torch.eye(attentions[0].size(-1)).to(device)
     with torch.no_grad():
         for attention in attentions:
             if head_fusion.startswith('mean'):
-                # mean_std fusion doesn't appear to make sense with rollout
                 attention_heads_fused = attention.mean(dim=0)
             elif head_fusion == "max":
                 attention_heads_fused = attention.amax(dim=0)
@@ -87,14 +87,13 @@ def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
             indices = indices[indices >= num_prefix_tokens]
             flat[indices] = 0
 
-            I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
+            I = torch.eye(attention_heads_fused.size(-1)).to(device)
             a = (attention_heads_fused + 1.0 * I) / 2
             a = a / a.sum(dim=-1)
             result = torch.matmul(a, result)
 
     # Look at the total attention between the prefix tokens (usually class tokens)
     # and the image patches
-    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
     mask = result[0, num_prefix_tokens:]
     width = int(mask.size(-1) ** 0.5)
     mask = mask.reshape(width, width).cpu().numpy()
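
The unchanged lines around this hunk are the attention-rollout recurrence from the vit-explain code that the previous hunk's removed comment pointed to: each layer's head-fused attention is averaged with the identity (to account for the residual connection), re-normalized, and multiplied into a running product, and the first row of that product gives the class token's attention over the patch tokens. A tiny self-contained illustration of the recurrence on random data (all names and sizes are illustrative only):

import torch

num_tokens, num_heads, num_layers = 5, 2, 3                     # 1 class token + 4 patch tokens
attentions = [torch.rand(num_heads, num_tokens, num_tokens).softmax(dim=-1)
              for _ in range(num_layers)]

result = torch.eye(num_tokens)
for attn in attentions:
    a = attn.mean(dim=0)                         # 'mean' head fusion
    a = (a + torch.eye(num_tokens)) / 2          # identity term for the residual path
    a = a / a.sum(dim=-1, keepdim=True)          # re-normalize rows to sum to 1
    result = a @ result                          # accumulate attention flow across layers
mask = result[0, 1:]                             # class token's attention over the 4 patch tokens
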
@@ -110,7 +109,6 @@ def visualize_attention(
 ) -> Tuple[List[Image.Image], Image.Image]:
     """Visualize attention maps and rollout for the given image and model."""
     model, extractor = load_model(model_name)
-    model = model.to('cuda')
     attention_maps = process_image(image, model, extractor)
 
     num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified
@@ -150,7 +148,7 @@ def visualize_attention(
         # Interpolate to match image size
         attn_map = attn_map.unsqueeze(0).unsqueeze(0)
         attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
-        attn_map = attn_map.squeeze().detach().cpu().numpy()  # Detach before converting to numpy
+        attn_map = attn_map.squeeze().cpu().numpy()  # Move to CPU before converting to numpy
 
         # Normalize attention map
         attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
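
.numpy() only accepts CPU tensors with no autograd history, so removing .detach() here relies on the extracted attention maps not requiring grad. If that ever changed, the conversion would raise PyTorch's "Can't call numpy() on Tensor that requires grad" error; a defensive sketch that works either way simply keeps the call, since detach() is essentially free when there is nothing to detach:

# Defensive sketch only, not part of this commit (attn_np is a hypothetical name).
attn_np = attn_map.squeeze().detach().cpu().numpy()
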
@@ -179,6 +177,9 @@ def visualize_attention(
         visualizations.append(vis_image)
         plt.close(fig)
 
+    # Ensure tensors are on CPU before converting to numpy
+    attentions_for_rollout = [attn.cpu() for attn in attentions_for_rollout]
+
     # Calculate rollout
     rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
 
@@ -209,7 +210,6 @@ def visualize_attention(
 
     return visualizations, rollout_image
 
-
 # Create Gradio interface
 iface = gr.Interface(
     fn=visualize_attention,
@@ -231,5 +231,5 @@ iface = gr.Interface(
     description="Upload an image and select a timm model to visualize its attention maps."
 )
 
-# Launch the interface with share=True to create a public link
-iface.launch(share=True)
+# Launch the interface
+iface.launch()
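
share=True only matters for tunnelling a temporary public link out of a private machine; a hosted Space already serves the app at its own URL, so the plain launch() is enough. Taken together, the device flow after this commit is: model onto CUDA in load_model, input tensor onto CUDA in process_image, and tensors pulled back to CPU before any NumPy conversion. A hedged end-to-end sketch (the model name and image path are placeholders):

from PIL import Image

# Sketch of the post-commit flow; load_model / process_image are the functions patched above.
model, extractor = load_model('vit_base_patch16_224')                    # model lands on CUDA inside load_model
attention_maps = process_image(Image.open('cat.jpg'), model, extractor)  # input moved to CUDA, attention extracted
# downstream code (visualize_attention) moves results back to CPU before calling .numpy()
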
 