rwightman HF staff commited on
Commit
500c37b
·
verified ·
1 Parent(s): 19f8d94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -12
app.py CHANGED
@@ -9,9 +9,10 @@ from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, Ablat
9
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
  from timm.data import create_transform
 
12
 
13
  # List of available timm models
14
- MODELS = timm.list_models()
15
 
16
  # List of available GradCAM methods
17
  CAM_METHODS = {
@@ -25,6 +26,16 @@ CAM_METHODS = {
25
  "FullGrad": FullGrad
26
  }
27
 
 
 
 
 
 
 
 
 
 
 
28
  def load_model(model_name):
29
  model = timm.create_model(model_name, pretrained=True)
30
  model.eval()
@@ -50,9 +61,14 @@ def process_image(image_path, model):
50
  tensor = transform(image).unsqueeze(0)
51
  return tensor
52
 
53
- def get_cam_image(model, image, target_layer, cam_method):
 
 
 
 
 
54
  cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
55
- grayscale_cam = cam(input_tensor=image)
56
 
57
  config = model.pretrained_cfg
58
  mean = torch.tensor(config['mean']).view(3, 1, 1)
@@ -79,20 +95,35 @@ def get_target_layer(model, target_layer_name):
79
  print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
80
  return None
81
 
82
- def explain_image(model_name, image_path, cam_method, feature_module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  model = load_model(model_name)
84
  image = process_image(image_path, model)
85
 
86
  target_layer = get_target_layer(model, feature_module)
87
 
88
  if target_layer is None:
89
- # Fallback to the last feature module or last convolutional layer
90
  feature_info = get_feature_info(model)
91
  if feature_info:
92
  target_layer = get_target_layer(model, feature_info[-1])
93
  print(f"Using last feature module: {feature_info[-1]}")
94
  else:
95
- # Fallback to finding last convolutional layer
96
  for name, module in reversed(list(model.named_modules())):
97
  if isinstance(module, torch.nn.Conv2d):
98
  target_layer = module
@@ -102,17 +133,35 @@ def explain_image(model_name, image_path, cam_method, feature_module):
102
  if target_layer is None:
103
  raise ValueError("Could not find a suitable target layer.")
104
 
105
- cam_image = get_cam_image(model, image, target_layer, cam_method)
106
- return cam_image
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def update_feature_modules(model_name):
109
  model = load_model(model_name)
110
  feature_modules = get_feature_info(model)
111
  return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
112
 
 
 
 
 
 
 
113
  with gr.Blocks() as demo:
114
- gr.Markdown("# Explainable AI with timm models")
115
- gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.")
116
 
117
  with gr.Row():
118
  with gr.Column():
@@ -120,17 +169,20 @@ with gr.Blocks() as demo:
120
  image_input = gr.Image(type="filepath", label="Upload Image")
121
  cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
122
  feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
 
123
  explain_button = gr.Button("Explain Image")
124
 
125
  with gr.Column():
126
  output_image = gr.Image(type="pil", label="Explained Image")
 
127
 
128
  model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
 
129
 
130
  explain_button.click(
131
  fn=explain_image,
132
- inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown],
133
- outputs=[output_image]
134
  )
135
 
136
  demo.launch()
 
9
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
10
  from pytorch_grad_cam.utils.image import show_cam_on_image
11
  from timm.data import create_transform
12
+ from timm.data import infer_imagenet_subset, ImageNetInfo
13
 
14
  # List of available timm models
15
+ MODELS = timm.list_pretrained()
16
 
17
  # List of available GradCAM methods
18
  CAM_METHODS = {
 
26
  "FullGrad": FullGrad
27
  }
28
 
29
+ class CustomDatasetInfo:
30
+ def __init__(self, label_names, label_descriptions=None):
31
+ self.label_names = label_names
32
+ self.label_descriptions = label_descriptions or label_names
33
+
34
+ def index_to_description(self, index, detailed=False):
35
+ if detailed and self.label_descriptions:
36
+ return self.label_descriptions[index]
37
+ return self.label_names[index]
38
+
39
  def load_model(model_name):
40
  model = timm.create_model(model_name, pretrained=True)
41
  model.eval()
 
61
  tensor = transform(image).unsqueeze(0)
62
  return tensor
63
 
64
+ def get_cam_image(model, image, target_layer, cam_method, target_class):
65
+ if target_class is not None and target_class != "highest scoring":
66
+ target = ClassifierOutputTarget(target_class)
67
+ else:
68
+ target = None
69
+
70
  cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
71
+ grayscale_cam = cam(input_tensor=image, targets=[target] if target else None)
72
 
73
  config = model.pretrained_cfg
74
  mean = torch.tensor(config['mean']).view(3, 1, 1)
 
95
  print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
96
  return None
97
 
98
+ def get_class_names(model):
99
+ dataset_info = None
100
+ label_names = model.pretrained_cfg.get("label_names", None)
101
+ label_descriptions = model.pretrained_cfg.get("label_descriptions", None)
102
+ if label_names is None:
103
+ imagenet_subset = infer_imagenet_subset(model)
104
+ if imagenet_subset:
105
+ dataset_info = ImageNetInfo(imagenet_subset)
106
+ else:
107
+ label_names = [f"LABEL_{i}" for i in range(model.num_classes)]
108
+ if dataset_info is None:
109
+ dataset_info = CustomDatasetInfo(
110
+ label_names=label_names,
111
+ label_descriptions=label_descriptions,
112
+ )
113
+ return dataset_info
114
+
115
+ def explain_image(model_name, image_path, cam_method, feature_module, target_class):
116
  model = load_model(model_name)
117
  image = process_image(image_path, model)
118
 
119
  target_layer = get_target_layer(model, feature_module)
120
 
121
  if target_layer is None:
 
122
  feature_info = get_feature_info(model)
123
  if feature_info:
124
  target_layer = get_target_layer(model, feature_info[-1])
125
  print(f"Using last feature module: {feature_info[-1]}")
126
  else:
 
127
  for name, module in reversed(list(model.named_modules())):
128
  if isinstance(module, torch.nn.Conv2d):
129
  target_layer = module
 
133
  if target_layer is None:
134
  raise ValueError("Could not find a suitable target layer.")
135
 
136
+ target_class_index = None if target_class == "highest scoring" else int(target_class.split(':')[0])
137
+ cam_image = get_cam_image(model, image, target_layer, cam_method, target_class_index)
138
+
139
+ with torch.no_grad():
140
+ out = model(image)
141
+ probabilities = out.squeeze(0).softmax(dim=0)
142
+ values, indices = torch.topk(probabilities, 5) # Top 5 predictions
143
+ dataset_info = get_class_names(model)
144
+ labels = [
145
+ f"{i}: {dataset_info.index_to_description(i.item(), detailed=True)} ({v.item():.2%})"
146
+ for i, v in zip(indices, values)
147
+ ]
148
+
149
+ return cam_image, "\n".join(labels)
150
 
151
  def update_feature_modules(model_name):
152
  model = load_model(model_name)
153
  feature_modules = get_feature_info(model)
154
  return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
155
 
156
+ def update_class_dropdown(model_name):
157
+ model = load_model(model_name)
158
+ dataset_info = get_class_names(model)
159
+ class_names = ["highest scoring"] + [f"{i}: {dataset_info.index_to_description(i, detailed=True)}" for i in range(model.num_classes)]
160
+ return gr.Dropdown(choices=class_names, value="highest scoring")
161
+
162
  with gr.Blocks() as demo:
163
+ gr.Markdown("# Explainable AI with timm models. NOTE: This is a WIP but some models are functioning.")
164
+ gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module and target class to visualize the explanation.")
165
 
166
  with gr.Row():
167
  with gr.Column():
 
169
  image_input = gr.Image(type="filepath", label="Upload Image")
170
  cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
171
  feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
172
+ class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
173
  explain_button = gr.Button("Explain Image")
174
 
175
  with gr.Column():
176
  output_image = gr.Image(type="pil", label="Explained Image")
177
+ prediction_text = gr.Textbox(label="Top 5 Predictions")
178
 
179
  model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
180
+ model_dropdown.change(fn=update_class_dropdown, inputs=[model_dropdown], outputs=[class_dropdown])
181
 
182
  explain_button.click(
183
  fn=explain_image,
184
+ inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
185
+ outputs=[output_image, prediction_text]
186
  )
187
 
188
  demo.launch()